"""
Ψπ (psipy): Symbolic-Numerical Toolkit for PDEs and Hamiltonian Mechanics
========================================================================

## Overview

Welcome to `psipy`, a comprehensive Python ecosystem designed to bridge the gap
between formal symbolic mathematics (via SymPy) and high-performance numerical
simulation (via NumPy/SciPy). This library provides a unified framework for
defining, analyzing, solving, and visualizing complex problems in:

- Partial Differential Equations (PDEs)
- Pseudo-Differential Operators (ΨDOs)
- Hamiltonian and Lagrangian Mechanics
- Semiclassical and Microlocal Analysis

The core philosophy is to allow users to move seamlessly from a formal symbolic
definition (such as a Lagrangian, a Hamiltonian from the included catalog, or a
PDE written in SymPy) to a robust numerical analysis, such as solving the PDE's
evolution, visualizing its phase-space geometry, or computing its semiclassical
spectrum.

## Core Components

The `psipy` ecosystem is composed of several powerful, interoperable modules:

- **`PDESolver`**: The main numerical engine. It parses symbolic PDEs and solves
  1D/2D, linear/nonlinear, time-dependent or stationary equations. It uses spectral
  (FFT) methods with high-order exponential integrators (like ETD-RK4) for robust
  time evolution.

- **`PseudoDifferentialOperator`**: A complete symbolic and numerical framework for
  Pseudo-Differential Operators (ΨDOs). It supports symbolic calculus (composition,
  commutators, adjoints) and microlocal analysis (ellipticity, characteristic sets),
  bridging formal definitions with numerical evaluation on grids.

- **`LagrangianHamiltonianConverter` & `HamiltonianSymbolicConverter`**: A symbolic
  toolkit for analytical mechanics. It performs purely symbolic Legendre transforms
  (L ↔ H) and can automatically generate formal symbolic PDEs (e.g., Schrödinger,
  Wave) from any given Hamiltonian symbol.

- **`HamiltonianCatalog`**: A vast, curated, and searchable symbolic database of
  **over 500** Hamiltonian systems. It spans classical mechanics, quantum chaos,
  biophysics, and more, providing a rich testbed for research and education.

- **`SymbolGeometry`**: A comprehensive analysis and visualization suite for 1D
  Hamiltonian systems. It connects classical geometry to quantum spectra by computing
  classical trajectories, periodic orbits, and the semiclassical energy spectrum via
  the **Gutzwiller trace formula** and **EBK quantization**.

- **`SymbolGeometry2D`**: An advanced 2D analysis toolkit for visualizing dynamical
  systems. It performs rigorous **caustic detection** by tracking the full 4×4 Jacobian,
  generates **Poincaré sections**, and analyzes **KAM tori**, providing a deep dive
  into 2D phase space geometry.
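
The EBK quantization mentioned for `SymbolGeometry` can be illustrated without the library. The following is a minimal NumPy sketch, not psipy's implementation: it assumes a 1D system with unit mass and ħ = 1, computes the action J(E) = ∮ p dq on a fixed grid, and solves the EBK condition J(E) = 2πħ(n + 1/2) by bisection. The helper names `action` and `ebk_level`, the integration window, and the energy bracket are illustrative assumptions.

```python
import numpy as np


def action(E, V, q_min, q_max, n=20001):
    # Closed-orbit action J(E) = ∮ p dq = 2 ∫ sqrt(2 (E - V(q))) dq for unit mass.
    # The window [q_min, q_max] must contain both turning points; outside the
    # classically allowed region the integrand is clipped to zero.
    q = np.linspace(q_min, q_max, n)
    p = np.sqrt(np.clip(2.0 * (E - V(q)), 0.0, None))
    # Trapezoidal rule written out explicitly (np.trapz is deprecated in NumPy 2).
    return 2.0 * np.sum(0.5 * (p[1:] + p[:-1]) * np.diff(q))


def ebk_level(n, V, hbar=1.0, q_range=(-6.0, 6.0), E_lo=0.0, E_hi=10.0, iters=60):
    # EBK condition with Maslov index 2 (two smooth turning points):
    # J(E) = 2 π hbar (n + 1/2). J grows monotonically with E, so bisection works.
    target = 2.0 * np.pi * hbar * (n + 0.5)
    for _ in range(iters):
        E_mid = 0.5 * (E_lo + E_hi)
        if action(E_mid, V, *q_range) < target:
            E_lo = E_mid
        else:
            E_hi = E_mid
    return 0.5 * (E_lo + E_hi)


def harmonic(q):
    # V(q) = q^2 / 2, whose EBK levels coincide with the exact ones, n + 1/2.
    return 0.5 * q**2


levels = [ebk_level(n, harmonic) for n in range(4)]
print(np.round(levels, 3))
```

For the harmonic potential the EBK levels are exact, which makes it a convenient sanity check; an anharmonic `V` works the same way as long as the window and bracket enclose the orbit. A plain trapezoidal rule keeps the sketch short, at the price of reduced accuracy near the square-root behavior at the turning points.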

## Typical Workflow

A typical use case combines several of these modules:

1. **Select a System**: Fetch a complex Hamiltonian (e.g., "henon_heiles")
   from the `HamiltonianCatalog`.

2. **Formulate the PDE**: Use `HamiltonianSymbolicConverter` to automatically
   generate the corresponding symbolic Schrödinger equation.

3. **Analyze Geometry**: Pass the Hamiltonian symbol to `SymbolGeometry2D`
   to visualize its classical trajectories, Poincaré sections, and chaotic regions.

4. **Solve Dynamics**: Pass the symbolic PDE to the `PDESolver` to
   simulate the quantum wave function's evolution in time.
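
Step 3 can be previewed with a hand-rolled computation that does not use `SymbolGeometry2D`. The minimal NumPy sketch below assumes the standard Hénon–Heiles form H = (p_x² + p_y²)/2 + (x² + y²)/2 + x²y − y³/3 (whether the catalog's "henon_heiles" entry uses exactly this normalization is an assumption), integrates Hamilton's equations with fixed-step RK4, and records the Poincaré section x = 0, p_x > 0 by linear interpolation. The energy E = 1/12, the initial point, the step size, and the helper names are illustrative choices.

```python
import numpy as np


def hamiltonian(s):
    # Standard Hénon-Heiles Hamiltonian; s = (x, y, px, py).
    x, y, px, py = s
    return 0.5 * (px**2 + py**2) + 0.5 * (x**2 + y**2) + x**2 * y - y**3 / 3.0


def rhs(s):
    # Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq.
    x, y, px, py = s
    return np.array([px, py, -x - 2.0 * x * y, -y - x**2 + y**2])


def rk4_step(s, dt):
    # Classical fourth-order Runge-Kutta step.
    k1 = rhs(s)
    k2 = rhs(s + 0.5 * dt * k1)
    k3 = rhs(s + 0.5 * dt * k2)
    k4 = rhs(s + dt * k3)
    return s + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def poincare_section(s0, dt=0.01, n_steps=20000):
    # Collect (y, py) each time the orbit crosses x = 0 upward (px > 0),
    # interpolating linearly between the two bracketing steps.
    points = []
    s = np.asarray(s0, dtype=float)
    for _ in range(n_steps):
        s_new = rk4_step(s, dt)
        if s[0] < 0.0 <= s_new[0] and s_new[2] > 0.0:
            w = -s[0] / (s_new[0] - s[0])
            crossing = (1.0 - w) * s + w * s_new
            points.append((crossing[1], crossing[3]))
        s = s_new
    return np.array(points), s


E = 1.0 / 12.0
y0, py0 = 0.1, 0.0
# Choose px0 > 0 so that the initial point lies on the energy surface H = E.
px0 = np.sqrt(2.0 * (E - (0.5 * y0**2 - y0**3 / 3.0)))
s0 = np.array([0.0, y0, px0, py0])

section, s_final = poincare_section(s0)
print(len(section), "crossings; energy drift:", abs(hamiltonian(s_final) - E))
```

Plotting `section[:, 0]` against `section[:, 1]` for several initial points on the same energy shell shows invariant curves (KAM tori) alongside scattered chaotic points. RK4 is not symplectic, so for very long runs a symplectic integrator is the more careful choice.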

## Example: Solving a Pseudo-Differential PDE

This example defines a 1D Schrödinger-type equation with a non-local,
relativistic kinetic term, i ∂ₜ u = √(1 - ∂ₓ²) u.
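
To make the symbol convention concrete before the full example: with ξ ↔ −i∂ₓ, the operator √(1 − ∂ₓ²) multiplies a Fourier mode e^{ikx} by √(1 + k²). The short NumPy sketch below checks this on a periodic grid by applying the symbol as a Fourier multiplier, which is the idea behind the constant-coefficient path of `PseudoDifferentialOperator.apply`. It does not call psipy; the grid size and the helper name `apply_symbol` are illustrative assumptions.

```python
import numpy as np

# Periodic grid on [0, 2π) and the matching angular wavenumbers ξ (integers here).
N, L = 256, 2.0 * np.pi
x = np.arange(N) * (L / N)
xi = 2.0 * np.pi * np.fft.fftfreq(N, d=L / N)


def apply_symbol(u, symbol):
    # Constant-coefficient ΨDO: Op(p)u = IFFT[p(ξ) · FFT[u]].
    return np.fft.ifft(symbol(xi) * np.fft.fft(u))


def relativistic(k):
    # Symbol of √(1 − ∂ₓ²) under the convention ξ ↔ −i∂ₓ.
    return np.sqrt(1.0 + k**2)


k = 5
mode = np.exp(1j * k * x)
result = apply_symbol(mode, relativistic)
print(np.max(np.abs(result - np.sqrt(1.0 + k**2) * mode)))
```

The printed discrepancy is at round-off level because e^{5ix} is exactly one FFT mode on this grid; a symbol that also depends on x no longer acts diagonally and needs the Kohn–Nirenberg path instead.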

```python
import numpy as np
from IPython.display import HTML
from sympy import Eq, Function, I, diff, sqrt, symbols

from psipy.solver import PDESolver, psiOp

# 1. Define symbolic variables
t, x, xi = symbols('t x xi', real=True)
u = Function('u')

# 2. Define the PDE symbolically
# The symbol of the operator √(1 - ∂ₓ²) is p(ξ) = √(1 + ξ²)
# (Fourier convention ξ ↔ -i ∂ₓ, so ξ² ↔ -∂ₓ²)
p_symbol = sqrt(1 + xi**2)  # exact SymPy sqrt; (1/2) would be the float 0.5

# The equation is: i ∂ₜ u = psiOp(p_symbol) u
equation = Eq(I * diff(u(t, x), t), psiOp(p_symbol, u(t, x)))

# 3. Create the solver
solver = PDESolver(equation)

# 4. Set up the simulation domain and initial condition
# (Gaussian packet centered at x = π with carrier wavenumber 5)
initial_packet = lambda x: np.exp(-(x - np.pi)**2 / 0.5) * np.exp(1j * 5.0 * x)
solver.setup(
    Lx=2 * np.pi, Nx=256,
    Lt=4.0, Nt=1000,
    initial_condition=initial_packet,
    boundary_condition='periodic'
)

# 5. Solve the PDE
solver.solve()

# 6. Animate the solution (in a Jupyter notebook)
ani = solver.animate()
HTML(ani.to_jshtml())
```
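
Because this example is linear with a constant-coefficient symbol, it also has an exact Fourier-space solution, u(t) = 𝓕⁻¹[e^{−i p(ξ) t} 𝓕(u₀)], which offers an independent check on the `PDESolver` output. The NumPy sketch below does not call psipy and assumes the domain is [0, 2π) (the example only fixes its length `Lx`). It propagates the same packet to t = 4 and compares the drift of its center with the spectrum-weighted group velocity ⟨p′(ξ)⟩, which should sit just below p′(5) = 5/√26 ≈ 0.98 because p′ is concave for ξ > 0.

```python
import numpy as np

# Same grid and initial packet as the example above (domain assumed to be [0, 2π)).
N, L, T = 256, 2.0 * np.pi, 4.0
x = np.arange(N) * (L / N)
xi = 2.0 * np.pi * np.fft.fftfreq(N, d=L / N)
u0 = np.exp(-(x - np.pi)**2 / 0.5) * np.exp(1j * 5.0 * x)

# i u_t = p(D) u  =>  û(ξ, t) = exp(-i p(ξ) t) û(ξ, 0): exact propagation in Fourier space.
p = np.sqrt(1.0 + xi**2)
u0_hat = np.fft.fft(u0)
uT = np.fft.ifft(np.exp(-1j * p * T) * u0_hat)


def circular_center(u):
    # Center of |u|^2 on the periodic domain, from the phase of its first Fourier moment.
    return np.angle(np.sum(np.abs(u)**2 * np.exp(1j * x)))


def wrap(a):
    # Map an angle to [-π, π).
    return (a + np.pi) % (2.0 * np.pi) - np.pi


# Spectrum-weighted group velocity <p'(ξ)>, with p'(ξ) = ξ / sqrt(1 + ξ^2).
weights = np.abs(u0_hat)**2
v_mean = np.sum(weights * xi / p) / np.sum(weights)

shift = wrap(circular_center(uT) - circular_center(u0))
print("group velocity at xi = 5:", 5.0 / np.sqrt(26.0))
print("mean group velocity:", v_mean)
print("shift mismatch:", wrap(shift - v_mean * T))
print("norm ratio:", np.linalg.norm(uT) / np.linalg.norm(u0))
```

Since |e^{−i p(ξ) t}| = 1, the exact propagator conserves the discrete L² norm; comparing the solver's final state with `uT` is one way to gauge its time-integration error on this problem.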
"""
from importlib.metadata import PackageNotFoundError, version

# Public imports
from .psiop import *
from .solver import *
from .physics import *
from .geometry_1d import *
from .geometry_2d import *
from .hamiltonian_catalog import *

# Package version (falls back to "unknown" when running from an uninstalled source tree)
try:
    __version__ = version("psipy")
except PackageNotFoundError:
    __version__ = "unknown"

# Names exported by `from psipy import *`
__all__ = [
    "PseudoDifferentialOperator",
    "PDESolver",
    "LagrangianHamiltonianConverter",
    "HamiltonianSymbolicConverter",
    "SymbolGeometry",
    "SymbolVisualizer",
    "SpectralAnalysis",
    "SymbolGeometry2D",
    "SymbolVisualizer2D",
    "Utilities2D",
]
class PseudoDifferentialOperator:
    """
    Pseudo-differential operator with dynamic symbol evaluation on spatial grids.
    Supports both 1D and 2D operators, and can be defined explicitly (symbol mode)
    or extracted automatically from symbolic equations (auto mode).

    Parameters
    ----------
    expr : sympy expression
        Symbolic expression representing the pseudo-differential symbol.
    vars_x : list of sympy symbols
        Spatial variables (e.g., [x] for 1D, [x, y] for 2D).
    var_u : sympy function, optional
        Applied function appearing in `expr` (e.g. u(x), or u(x, y) in 2D); in auto
        mode it is replaced by a plane wave to extract the operator symbol.
    mode : str, {'symbol', 'auto'}
        - 'symbol': directly uses expr as the operator symbol.
        - 'auto': computes the symbol by applying expr to exp(i x ξ)
          (exp(i (x ξ + y η)) in 2D) and dividing by that exponential.

    Attributes
    ----------
    dim : int
        Spatial dimension (1 or 2).
    fft, ifft : callable
        Forward and inverse FFT (scipy.fft.fft/ifft in 1D, scipy.fft.fft2/ifft2 in 2D).
    p_func : callable
        Lambdified symbol, ready for numerical evaluation.

    Notes
    -----
    - In 'symbol' mode, `expr` should be expressed in terms of spatial variables and frequency variables (ξ, η).
    - In 'auto' mode, the symbol is derived by applying the differential expression to a complex exponential.
    - Frequency variables are internally named 'xi' and 'eta' for consistency.
    - Uses numpy for numerical evaluation and scipy.fft for FFT operations.

    Examples
    --------
    >>> # Example 1: 1D negative Laplacian -∂ₓ² (symbol mode)
    >>> from sympy import symbols
    >>> x, xi = symbols('x xi', real=True)
    >>> op = PseudoDifferentialOperator(expr=xi**2, vars_x=[x], mode='symbol')

    >>> # Example 2: 1D transport operator (auto mode)
    >>> from sympy import Function
    >>> u = Function('u')
    >>> expr = u(x).diff(x)
    >>> op = PseudoDifferentialOperator(expr=expr, vars_x=[x], var_u=u(x), mode='auto')
    """

    def __init__(self, expr, vars_x, var_u=None, mode='symbol'):
        self.dim = len(vars_x)
        self.mode = mode
        self.symbol_cached = None
        self.expr = expr
        self.vars_x = vars_x

        if self.dim == 1:
            x, = vars_x
            xi_internal = symbols('xi', real=True)
            # Rename any user symbol named 'xi' (whatever its assumptions) to the
            # internal real symbol, so that lambdify receives the expected argument.
            expr = expr.subs({s: xi_internal for s in expr.free_symbols if s.name == 'xi'})
            self.fft = partial(fft, workers=FFT_WORKERS)
            self.ifft = partial(ifft, workers=FFT_WORKERS)

            if mode == 'symbol':
                self.p_func = lambdify((x, xi_internal), expr, 'numpy')
                self.symbol = expr
            elif mode == 'auto':
                if var_u is None:
                    raise ValueError("var_u must be provided in mode='auto'")
                exp_i = exp(I * x * xi_internal)
                P_ei = expr.subs(var_u, exp_i)
                symbol = simplify(P_ei / exp_i)
                symbol = expand(symbol)
                self.symbol = symbol
                self.p_func = lambdify((x, xi_internal), symbol, 'numpy')
            else:
                raise ValueError("mode must be 'auto' or 'symbol'")

        elif self.dim == 2:
            x, y = vars_x
            xi_internal, eta_internal = symbols('xi eta', real=True)
            # Same renaming for both frequency variables.
            renames = {s: xi_internal for s in expr.free_symbols if s.name == 'xi'}
            renames.update({s: eta_internal for s in expr.free_symbols if s.name == 'eta'})
            expr = expr.subs(renames)
            self.fft = partial(fft2, workers=FFT_WORKERS)
            self.ifft = partial(ifft2, workers=FFT_WORKERS)

            if mode == 'symbol':
                self.symbol = expr
                self.p_func = lambdify((x, y, xi_internal, eta_internal), expr, 'numpy')
            elif mode == 'auto':
                if var_u is None:
                    raise ValueError("var_u must be provided in mode='auto'")
                exp_i = exp(I * (x * xi_internal + y * eta_internal))
                P_ei = expr.subs(var_u, exp_i)
                symbol = simplify(P_ei / exp_i)
                symbol = expand(symbol)
                self.symbol = symbol
                self.p_func = lambdify((x, y, xi_internal, eta_internal), symbol, 'numpy')
            else:
                raise ValueError("mode must be 'auto' or 'symbol'")

        else:
            raise NotImplementedError("Only 1D and 2D supported")

        if mode == 'auto':
            print("\nsymbol = ")
            pprint(self.symbol, num_columns=NUM_COLS)

    def evaluate(self, X, Y, KX, KY, cache=True):
        """
        Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.

        The method dynamically selects between 1D and 2D evaluation based on the spatial dimension.
        If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation.
        The cache is not keyed on the grids: once filled, it is returned for any arguments until
        `clear_cache()` is called.

        Parameters
        ----------
        X, Y : ndarray
            Spatial grid coordinates. In 1D, Y is ignored.
        KX, KY : ndarray
            Frequency grid coordinates. In 1D, KY is ignored.
        cache : bool, default=True
            If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.

        Returns
        -------
        ndarray
            Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1D or 2D.
        """
        if cache and self.symbol_cached is not None:
            return self.symbol_cached

        if self.dim == 1:
            symbol = self.p_func(X, KX)
        elif self.dim == 2:
            symbol = self.p_func(X, Y, KX, KY)
        else:
            raise NotImplementedError("Only 1D and 2D supported")

        if cache:
            self.symbol_cached = symbol

        return symbol

    def clear_cache(self):
        """
        Clear cached symbol evaluations.
        """
        self.symbol_cached = None

    def apply(self, u, x_grid, kx, boundary_condition='periodic',
              y_grid=None, ky=None, dealiasing_mask=None,
              freq_window='gaussian', clamp=1e6, space_window=False):
        """
        Apply the pseudo-differential operator to the input field u.

        This method dispatches the application of the pseudo-differential operator based on:

        - Whether the symbol is spatially dependent (x/y)
        - The boundary condition in use (periodic or dirichlet)

        Supported operations:

        - Constant-coefficient symbols: applied via Fourier multiplication.
        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.

        Dispatch logic:

        - constant symbol, periodic BC: u ↦ Op(p)(D) u = 𝓕⁻¹[ p(ξ) 𝓕(u) ]
        - x-dependent symbol, periodic BC: u ↦ Op(p)(x, D) u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ,
          evaluated with FFTs (faster)
        - Dirichlet BC (any symbol): the same Kohn–Nirenberg integral, evaluated
          without assuming periodicity (slower)

        Parameters
        ----------
        u : ndarray
            Function to which the operator is applied
        x_grid : ndarray
            Spatial grid in x direction
        kx : ndarray
            Frequency grid in x direction
        boundary_condition : str
            'periodic' or 'dirichlet'
        y_grid : ndarray, optional
            Spatial grid in y direction (for 2D)
        ky : ndarray, optional
            Frequency grid in y direction (for 2D)
        dealiasing_mask : ndarray, optional
            Dealiasing mask (used only on the constant-symbol periodic path)
        freq_window : str
            Frequency windowing ('gaussian' or 'hann'); Kohn–Nirenberg paths only
        clamp : float
            Clamp symbol values to [-clamp, clamp]; Kohn–Nirenberg paths only
        space_window : bool
            Apply spatial windowing; Kohn–Nirenberg paths only

        Returns
        -------
        ndarray
            Result of applying the operator

        Raises
        ------
        ValueError
            If `boundary_condition` is neither 'periodic' nor 'dirichlet'.
        """
        # Check if symbol depends on spatial variables
        is_spatial = self._is_spatial_dependent()

        # Case 1: Constant symbol with periodic BC (fast path)
        if not is_spatial and boundary_condition == 'periodic':
            return self._apply_constant_fft(u, x_grid, kx, y_grid, ky, dealiasing_mask)

        # Case 2: Spatial symbol with periodic BC
        elif boundary_condition == 'periodic':
            symbol_func = self._get_symbol_func()
            return kohn_nirenberg_fft(
                u_vals=u,
                symbol_func=symbol_func,
                x_grid=x_grid,
                kx=kx,
                fft_func=self.fft,
                ifft_func=self.ifft,
                dim=self.dim,
                y_grid=y_grid,
                ky=ky,
                freq_window=freq_window,
                clamp=clamp,
                space_window=space_window
            )

        # Case 3: Dirichlet BC (non-periodic)
        elif boundary_condition == 'dirichlet':
            symbol_func = self._get_symbol_func()

            if self.dim == 1:
                return kohn_nirenberg_nonperiodic(
                    u_vals=u,
                    x_grid=x_grid,
                    xi_grid=kx,
                    symbol_func=symbol_func,
                    freq_window=freq_window,
                    clamp=clamp,
                    space_window=space_window
                )
            elif self.dim == 2:
                return kohn_nirenberg_nonperiodic(
                    u_vals=u,
                    x_grid=(x_grid, y_grid),
                    xi_grid=(kx, ky),
                    symbol_func=symbol_func,
                    freq_window=freq_window,
                    clamp=clamp,
                    space_window=space_window
                )

        else:
            raise ValueError(f"Invalid boundary condition '{boundary_condition}'")

    def _is_spatial_dependent(self):
        """
        Check if the symbol depends on spatial variables.

        Returns
        -------
        bool
            True if symbol depends on x (or x, y)
        """
        if self.dim == 1:
            return self.symbol.has(self.vars_x[0])
        elif self.dim == 2:
            x, y = self.vars_x
            return self.symbol.has(x) or self.symbol.has(y)
        else:
            return False

    def _get_symbol_func(self):
        """
        Get a lambdified version of the symbol.

        Returns
        -------
        callable
            Lambdified symbol function
        """
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            return lambdify((x, xi), self.symbol, 'numpy')
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            return lambdify((x, y, xi, eta), self.symbol, 'numpy')
        else:
            raise NotImplementedError("Only 1D and 2D supported")

    def _apply_constant_fft(self, u, x_grid, kx, y_grid, ky, dealiasing_mask):
        """
        Apply a constant-coefficient pseudo-differential operator in Fourier space.

        This method assumes the symbol is diagonal in the Fourier basis and acts as a
        multiplication operator. It performs the operation:

            Op(p)u = 𝓕⁻¹[ p(k) · 𝓕[u](k) ]

        where:
        - p(k) is the operator symbol evaluated on the frequency grid (at x = 0, which
          is harmless because the symbol does not depend on x)
        - 𝓕 denotes the forward Fourier transform
        - 𝓕⁻¹ denotes the inverse Fourier transform

        The dealiasing mask, if given, is applied in Fourier space before the inverse transform.

        Parameters
        ----------
        u : ndarray
            Input function
        x_grid : ndarray
            Spatial grid (x)
        kx : ndarray
            Frequency grid (x)
        y_grid : ndarray, optional
            Spatial grid (y, for 2D)
        ky : ndarray, optional
            Frequency grid (y, for 2D)
        dealiasing_mask : ndarray, optional
            Dealiasing mask

        Returns
        -------
        ndarray
            Operator applied to u (complex array)
        """
        u_hat = self.fft(u)

        # Evaluate symbol at grid points
        if self.dim == 1:
            X_dummy = np.zeros_like(kx)
            symbol_vals = self.p_func(X_dummy, kx)
        elif self.dim == 2:
            KX, KY = np.meshgrid(kx, ky, indexing='ij')
            X_dummy = np.zeros_like(KX)
            Y_dummy = np.zeros_like(KY)
            symbol_vals = self.p_func(X_dummy, Y_dummy, KX, KY)
        else:
            raise ValueError("Only 1D and 2D supported")

        # Apply symbol
        u_hat *= symbol_vals

        # Apply dealiasing
        if dealiasing_mask is not None:
            u_hat *= dealiasing_mask

        return self.ifft(u_hat)

    def principal_symbol(self, order=1):
        """
        Compute the truncated high-frequency expansion of the pseudo-differential symbol.

        This method expands the symbol as |ξ| → ∞ and keeps the leading terms, which
        carry the principal part. For 2D symbols the expansion is performed in polar
        coordinates (radial variable ρ, angular dependence preserved), then converted
        back to Cartesian form.

        Parameters
        ----------
        order : int
            Truncation order passed to `sympy.series` at infinity: terms of order
            O(ρ^(-order)) and smaller are dropped, where ρ = ξ in 1D (taken along
            ξ → +∞) or ρ = sqrt(ξ² + η²) in 2D.

        Returns
        -------
        sympy.Expr
            The truncated expansion. With order=1 it keeps every term of degree > -1,
            so for a symbol of order m ≥ 0 it contains the principal (degree-m) term
            plus any lower-order terms of nonnegative degree.

        Notes
        -----
        - In 1D, uses a direct series expansion in ξ.
        - In 2D, expands in the radial variable ρ while preserving angular dependence.
        - Useful for microlocal analysis and constructing parametrices.
        """

        p = self.symbol
        if self.dim == 1:
            xi = symbols('xi', real=True, positive=True)
            # The class stores xi as a real symbol; symbols with different assumptions
            # are distinct, so swap in the positive one before expanding at +oo.
            p = p.subs(symbols('xi', real=True), xi)
            return simplify(series(p, xi, oo, n=order).removeO())
        elif self.dim == 2:
            # Use the same real (ξ, η) symbols as the rest of the class.
            xi, eta = symbols('xi eta', real=True)
            # Homogeneous radial expansion: we set (ξ, η) = ρ (cosθ, sinθ)
            rho, theta = symbols('rho theta', real=True, positive=True)
            p_rho = p.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
            expansion = series(p_rho, rho, oo, n=order).removeO()
            # Revert back to (ξ, η)
            r = sqrt(xi**2 + eta**2)
            expansion_cart = expansion.subs({rho: r, cos(theta): xi / r, sin(theta): eta / r})
            return simplify(powdenest(expansion_cart, force=True))

    def is_homogeneous(self, tol=1e-10):
        """
        Check whether the symbol is (positively) homogeneous in the frequency variables.

        Parameters
        ----------
        tol : float
            Currently unused; the check is purely symbolic.

        Returns
        -------
        (bool, Rational or float or None)
            Tuple (is_homogeneous, degree) where:
            - is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η) for λ > 0
            - degree: the detected degree m if homogeneous, or None
        """
        from sympy import symbols, simplify

        l = symbols('l', real=True, positive=True)

        if self.dim == 1:
            # Swap the class's real xi for a positive one so that the scaling is
            # visible to simplify (symbols with different assumptions are distinct).
            xi = symbols('xi', real=True, positive=True)
            p = self.symbol.subs(symbols('xi', real=True), xi)
            ratio = simplify(p.subs(xi, l * xi) / p)
            if ratio.has(xi):
                return False, None

        elif self.dim == 2:
            xi, eta = symbols('xi eta', real=True, positive=True)
            xi_r, eta_r = symbols('xi eta', real=True)
            p = self.symbol.subs({xi_r: xi, eta_r: eta})
            ratio = simplify(p.subs({xi: l * xi, eta: l * eta}) / p)
            # If ratio == l**m with no (xi, eta) left, it's homogeneous
            if ratio.has(xi, eta):
                return False, None

        else:
            return False, None

        # ratio == 1 means degree 0; otherwise expect a pure power l**m.
        if ratio == 1:
            return True, 0
        base, degree = ratio.as_base_exp()
        if base == l:
            return True, degree
        return False, None
 463
 464    def symbol_order(self, max_order=10, tol=1e-3):
 465        """
 466        Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.
 467    
 468        This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η)
 469        as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate,
 470        which is essential for understanding the regularity and mapping properties of the corresponding operator.
 471    
 472        The function uses symbolic preprocessing to ensure proper factorization of frequency variables,
 473        especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).
 474    
 475        Parameters
 476        ----------
 477        max_order : int, optional
 478            Maximum number of terms to consider in the series expansion. Default is 10.
 479        tol : float, optional
 480            Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small,
 481            the detected order may be discarded. Default is 1e-3.
 482    
 483        Returns
 484        -------
 485        float or None
 486            - If the symbol is homogeneous, returns its exact homogeneity degree as a float.
 487            - Otherwise, estimates the dominant asymptotic order from leading terms in the expansion.
 488            - Returns None if no valid order could be determined.
 489    
 490        Notes
 491        -----
 492        - In 1D:
 493            Two strategies are used:
 494                1. Expand directly in xi at infinity.
 495                2. Substitute xi = 1/z and expand around z = 0.
 496    
 497        - In 2D:
 498            - Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
 499            - Expand in rho at infinity, then extract the leading term's power.
 500            - An alternative substitution using 1/z is also tried if the first method fails.
 501    
 502        - Preprocessing steps:
 503            - Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
 504            - Power expressions are factored explicitly to ensure correct symbolic scaling.
 505    
 506        - If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.
 507        
 508        - For non-homogeneous symbols, only the principal asymptotic term is considered.
 509    
 510        Raises
 511        ------
 512        NotImplementedError
 513            If the spatial dimension is neither 1 nor 2.
 514        """
        from sympy import (
            symbols, series, simplify, sqrt, cos, sin, oo, powdenest, radsimp,
            expand, Add, Rational, degree
        )

        def preprocess_sqrt(expr, freq):
            # Factor the frequency out of square roots: sqrt(b) = freq * sqrt(b / freq**2),
            # valid for freq > 0. SymPy stores sqrt(b) as Pow(b, 1/2), so match on the
            # exponent (comparing e.func with the sqrt helper never succeeds).
            return expr.replace(
                lambda e: e.is_Pow and e.exp == Rational(1, 2) and freq in e.free_symbols,
                lambda e: freq * sqrt(1 + (e.base - freq**2) / freq**2)
            )

        def preprocess_power(expr, freq):
            # Factor the top power of freq out of polynomial bases:
            # b**s = freq**(k*s) * (b / freq**k)**s with k = deg_freq(b), valid for freq > 0.
            def factor_out(e):
                k = degree(e.base, freq)
                return freq**(k * e.exp) * expand(e.base / freq**k)**e.exp

            return expr.replace(
                lambda e: (e.is_Pow and not e.base.is_Symbol
                           and freq in e.base.free_symbols
                           and e.base.is_polynomial(freq)),
                factor_out
            )

        def leading_at_infinity(s, freq):
            # Dominant part of a truncated expansion as freq -> +oo: the terms carrying the
            # largest power of freq. (s.as_leading_term(freq) would instead return the
            # dominant term as freq -> 0.) Returns (lead, power, coefficient).
            terms = Add.make_args(expand(powdenest(s, force=True)))
            powers = [t.as_powers_dict().get(freq, 0) for t in terms]
            top = max(powers)
            coeff = Add(*[t / freq**top for t, k in zip(terms, powers) if k == top])
            return freq**top * coeff, top, simplify(coeff)

        def validate_order(power, coeff, vars_x, tol):
            if power is None:
                return None
            if any(v in coeff.free_symbols for v in vars_x):
                print("⚠️ Coefficient depends on spatial variables; ignoring")
                return None
            try:
                # complex() also accepts purely imaginary coefficients such as I
                coeff_val = abs(complex(coeff.evalf()))
                if coeff_val < tol:
                    print(f"⚠️ Coefficient too small ({coeff_val:.2e} < {tol})")
                    return None
            except Exception as e:
                print(f"⚠️ Coefficient evaluation failed: {e}")
                return None
            return int(power) if power == int(power) else float(power)

        # Homogeneity check (the name `degree` is reserved for sympy.degree above)
        is_homog, homog_degree = self.is_homogeneous()
        if is_homog:
            return float(homog_degree)
        else:
            print("⚠️ The symbol is not homogeneous. The asymptotic order is not well defined.")

        if self.dim == 1:
            x = self.vars_x[0]
            # Operator symbols use a real xi elsewhere in this class; expand in a positive
            # copy so that roots and powers simplify as xi -> +oo.
            xi_real = symbols('xi', real=True)
            xi = symbols('xi', real=True, positive=True)
            p = self.symbol.subs(xi_real, xi)

            try:
                print("1D symbol_order - method 1")
                expr = preprocess_sqrt(p, xi)
                s = series(expr, xi, oo, n=max_order).removeO()
                lead, power, coeff = leading_at_infinity(s, xi)
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x], tol)
                if order is not None:
                    return order
            except Exception as e:
                print(f"⚠️ expansion at infinity failed: {e}")

            try:
                print("1D symbol_order - method 2")
                # Substitute xi = 1/z: a leading term c*z**k at z -> 0+ means order -k in xi
                z = symbols('z', real=True, positive=True)
                expr_z = p.subs(xi, 1/z)
                s = series(expr_z, z, 0, n=max_order).removeO()
                lead = simplify(powdenest(s.as_leading_term(z), force=True))
                power = lead.as_powers_dict().get(z, 0)  # constant leading term -> order 0
                coeff = lead / z**power
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x], tol)
                if order is not None:
                    return -order
            except Exception as e:
                print(f"⚠️ fallback z failed: {e}")
            return None

        elif self.dim == 2:
            x, y = self.vars_x
            # Accept frequency symbols declared either real or positive; otherwise the
            # polar substitution below silently has no effect.
            xi, eta = symbols('xi eta', real=True)
            xi_p, eta_p = symbols('xi eta', real=True, positive=True)
            rho, theta = symbols('rho theta', real=True, positive=True)

            try:
                print("2D symbol_order - method 1")
                polar = {xi: rho * cos(theta), eta: rho * sin(theta),
                         xi_p: rho * cos(theta), eta_p: rho * sin(theta)}
                p_rho = self.symbol.subs(polar)
                p_rho = preprocess_power(preprocess_sqrt(p_rho, rho), rho)
                s = series(simplify(p_rho), rho, oo, n=max_order).removeO()
                lead, power, coeff = leading_at_infinity(s, rho)
                coeff = radsimp(coeff)
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x, y], tol)
                if order is not None:
                    return order
            except Exception as e:
                print(f"⚠️ polar expansion failed: {e}")

            try:
                print("2D symbol_order - method 2")
                z = symbols('z', real=True, positive=True)
                xi_eta = {xi: cos(theta) / z, eta: sin(theta) / z,
                          xi_p: cos(theta) / z, eta_p: sin(theta) / z}
                p_z = self.symbol.subs(xi_eta)
                s = series(simplify(p_z), z, 0, n=max_order).removeO()
                lead = radsimp(simplify(powdenest(s.as_leading_term(z), force=True)))
                power = lead.as_powers_dict().get(z, 0)
                coeff = lead / z**power
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x, y], tol)
                if order is not None:
                    return -order
            except Exception as e:
                print(f"⚠️ fallback z (2D) failed: {e}")
            return None

        else:
            raise NotImplementedError("Only 1D and 2D supported.")
 636
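The fallback used by `symbol_order` (substitute ξ = 1/z and read the order off the leading term as z → 0⁺) can be sketched as a standalone SymPy function. This is a minimal sketch assuming a 1D symbol that depends on ξ only and is written with a positive SymPy symbol; `symbol_order_1d` is a hypothetical helper name, not part of psipy.

```python
import sympy as sp


def symbol_order_1d(p, xi, n=4):
    """Order of a 1D symbol p(xi) as xi -> +oo (hypothetical standalone helper).

    Substitutes xi = 1/z and reads the exponent k of the leading term c*z**k
    as z -> 0+; the order in xi is then -k.
    """
    z = sp.Symbol('z', positive=True)
    s = sp.series(p.subs(xi, 1 / z), z, 0, n=n).removeO()
    lead = s.as_leading_term(z)
    k = lead.as_powers_dict().get(z, 0)   # constant leading term -> order 0
    return -k


xi = sp.Symbol('xi', positive=True)
print(symbol_order_1d(sp.sqrt(1 + xi**2), xi))   # Japanese bracket <xi>: order 1
print(symbol_order_1d(3*xi**2 + xi, xi))          # order 2
print(symbol_order_1d(1 / (1 + xi**2), xi))       # order -2
```

Reading the exponent at z → 0 avoids having to pick the dominant term at infinity by hand, which is why it works as a fallback.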
 637    
 638    def asymptotic_expansion(self, order=3):
 639        """
 640        Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).
 641    
 642        This method expands the pseudo-differential symbol in inverse powers of the 
 643        frequency variable(s), either in 1D or 2D. It handles both polynomial and 
 644        exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.
 645    
 646        The expansion is performed directly in Cartesian coordinates for 1D symbols.
 647        For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion 
 648        at infinity in ρ, then converts the result back to Cartesian coordinates.
 649    
 650        Parameters
 651        ----------
 652        order : int, optional
 653            Maximum order of the asymptotic expansion. Default is 3.
 654    
 655        Returns
 656        -------
 657        sympy.Expr
 658            The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates.
 659            If expansion fails, returns the original unexpanded symbol.
 660    
        Notes
        -----
        - In 1D: the expansion is performed directly in terms of ξ, as ξ → +∞ (for
          ξ → −∞ it may differ, e.g. sqrt(1 + ξ²) ~ |ξ|).
        - In 2D: the symbol is first rewritten in polar coordinates (ρ, θ), expanded
          asymptotically as ρ → ∞, then converted back to Cartesian coordinates (ξ, η).
        - Exponential symbols exp(f) are handled by expanding the argument f first.
        - In 2D the symbol is simplified (via `simplify`) before the polar substitution,
          which often helps the series expansion succeed.
        - Robust to failures: exceptions are caught, a warning is printed, and the
          unexpanded symbol is returned instead of raising.
        - The result is simplified with `powdenest` and `simplify` (plus `expand` in 2D)
          for readability.
 669        """
        p = self.symbol

        if self.dim == 1:
            # Operator symbols use a real xi elsewhere in this class; expand in a positive
            # copy so roots and powers simplify at +oo, then map the result back.
            xi_real = symbols('xi', real=True)
            xi = symbols('xi', real=True, positive=True)
            p_pos = p.subs(xi_real, xi)

            try:
                # Case: exponential function, expand its argument first
                if p_pos.func == exp and len(p_pos.args) == 1:
                    arg_series = series(p_pos.args[0], xi, oo, n=order).removeO()
                    expanded = series(exp(expand(arg_series)), xi, oo, n=order).removeO()
                else:
                    expanded = series(p_pos, xi, oo, n=order).removeO()
                return simplify(powdenest(expanded, force=True)).subs(xi, xi_real)

            except Exception as e:
                print(f"Warning: 1D expansion failed: {e}")
                return p

        elif self.dim == 2:
            # Accept frequency symbols declared either real or positive; otherwise the
            # polar substitution below silently has no effect.
            xi, eta = symbols('xi eta', real=True)
            xi_p, eta_p = symbols('xi eta', real=True, positive=True)
            rho, theta = symbols('rho theta', real=True, positive=True)

            # Normalize before substitution
            p = simplify(p)

            # Substitute polar coordinates
            p_polar = p.subs({
                xi: rho * cos(theta), eta: rho * sin(theta),
                xi_p: rho * cos(theta), eta_p: rho * sin(theta)
            })

            try:
                # Handle exponentials
                if p_polar.func == exp and len(p_polar.args) == 1:
                    arg = p_polar.args[0]
                    arg_series = series(arg, rho, oo, n=order).removeO()
                    expanded = series(exp(expand(arg_series)), rho, oo, n=order).removeO()
                else:
                    expanded = series(p_polar, rho, oo, n=order).removeO()

                # Convert back to Cartesian
                norm = sqrt(xi**2 + eta**2)
                expansion_cart = expanded.subs({
                    rho: norm,
                    cos(theta): xi / norm,
                    sin(theta): eta / norm
                })

                # Final simplifications
                result = simplify(powdenest(expansion_cart, force=True))
                result = expand(result)
                return result

            except Exception as e:
                print(f"Warning: 2D expansion failed: {e}")
                return p

        else:
            raise NotImplementedError("Only 1D and 2D supported.")
 728            
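As an illustration of the 1D branch of `asymptotic_expansion`, the sketch below expands the Japanese bracket ⟨ξ⟩ = sqrt(1 + ξ²) at ξ → +∞ with plain SymPy, using a positive ξ as assumed above, and checks numerically that the truncation error shrinks as ξ grows. It does not call psipy itself.

```python
import sympy as sp

# Assumption: 1D symbol in xi only; a positive xi lets SymPy expand at +oo.
xi = sp.Symbol('xi', positive=True)
p = sp.sqrt(1 + xi**2)   # Japanese bracket <xi>, a symbol of order 1

# Expansion in inverse powers of xi, as in the 1D branch above
expansion = sp.series(p, xi, sp.oo, n=4).removeO()
print(expansion)   # leading term xi, followed by inverse powers of xi

# The truncation error decays as xi grows
for xi0 in (10, 100, 1000):
    print(xi0, float((p - expansion).subs(xi, xi0)))
```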
 729    def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
 730        """
 731        Compose two pseudo-differential operators using an asymptotic expansion
 732        in the chosen quantization scheme (Kohn–Nirenberg or Weyl).
 733    
 734        Parameters
 735        ----------
 736        other : PseudoDifferentialOperator
            The operator to compose with this one (self ∘ other: `other` acts first).
        order : int, default=1
            Maximum order of the asymptotic expansion.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization mode:
            - 'kn' : Kohn–Nirenberg quantization (left-quantized)
            - 'weyl' : Weyl symmetric quantization
        sign_convention : {'standard', 'inverse'}, optional
            Controls the phase-factor convention in both modes:
            - 'standard' → D = -i∂ₓ, gives [x, ξ] = +i (physics convention)
            - 'inverse' → D = +i∂ₓ, gives [x, ξ] = -i (mathematical adjoint convention)
            If None, defaults to 'standard'.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the composed symbol up to the given order.

        Notes
        -----
        Let s = -1 for 'standard' and s = +1 for 'inverse'.

        - In 1D (Kohn–Nirenberg):
            (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (s i)ⁿ ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ)
        - In 1D (Weyl):
            (p # q)(x, ξ) = exp[(s i/2)(∂_ξ^p ∂_x^q − ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ),
            truncated at total derivative order `order`.
        - In 2D the same formulas hold with sums over both pairs (x, ξ) and (y, η).

        Examples
        --------
        >>> x, xi = symbols('x xi', real=True)
        >>> X_op = PseudoDifferentialOperator(x, [x], mode='symbol')    # multiplication by x
        >>> D_op = PseudoDifferentialOperator(xi, [x], mode='symbol')   # D = -i d/dx
        >>> XD = X_op.compose_asymptotic(D_op, order=1)                 # σ(X∘D) = x ξ
        >>> DX = D_op.compose_asymptotic(X_op, order=1)                 # σ(D∘X) = x ξ − i
        >>> W = X_op.compose_asymptotic(D_op, order=3, mode='weyl')     # Weyl product x # ξ
        """

        from sympy import I, diff, factorial, simplify, symbols

        assert self.dim == other.dim, "Operator dimensions must match"
        if mode not in ('kn', 'weyl'):
            raise ValueError("mode must be either 'kn' or 'weyl'")
        p, q = self.symbol, other.symbol

        # Default sign convention; s = -1 corresponds to D = -i ∂ₓ
        if sign_convention is None:
            sign_convention = 'standard'
        sign = -1 if sign_convention == 'standard' else +1

        # --- 1D case ---
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            result = 0

            if mode == 'kn':  # Kohn–Nirenberg: Σₙ (s i)ⁿ/n! ∂_ξⁿ p ∂_xⁿ q
                for n in range(order + 1):
                    result += (sign * I) ** n / factorial(n) * diff(p, xi, n) * diff(q, x, n)

            else:  # Weyl (Moyal) product
                # exp[(s i/2)(A − B)] with A = ∂_ξ^p ∂_x^q, B = ∂_x^p ∂_ξ^q, expanded binomially:
                # p receives ∂_ξ^k ∂_x^(n−k), q receives ∂_x^k ∂_ξ^(n−k), with sign (−1)^(n−k)
                for n in range(order + 1):
                    for k in range(n + 1):
                        coeff = (sign * I / 2) ** n * (-1) ** (n - k) / (factorial(k) * factorial(n - k))
                        result += coeff * diff(p, xi, k, x, n - k) * diff(q, x, k, xi, n - k)

            return simplify(result)

        # --- 2D case ---
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            result = 0

            if mode == 'kn':  # multi-index Kohn–Nirenberg formula
                for n in range(order + 1):
                    for k1 in range(n + 1):
                        k2 = n - k1
                        result += ((sign * I) ** n / (factorial(k1) * factorial(k2))
                                   * diff(p, xi, k1, eta, k2) * diff(q, x, k1, y, k2))

            else:  # Weyl (Moyal) product in 2D
                # exp[(s i/2) Σ_d (∂_{ξ_d}^p ∂_{x_d}^q − ∂_{x_d}^p ∂_{ξ_d}^q)], expanded with the
                # multinomial theorem over a + b + c + d = n:
                # p receives ∂_ξ^a ∂_x^b ∂_η^c ∂_y^d, q receives ∂_x^a ∂_ξ^b ∂_y^c ∂_η^d.
                for n in range(order + 1):
                    for a in range(n + 1):
                        for b in range(n + 1 - a):
                            for c in range(n + 1 - a - b):
                                d = n - a - b - c
                                coeff = ((sign * I / 2) ** n * (-1) ** (b + d)
                                         / (factorial(a) * factorial(b) * factorial(c) * factorial(d)))
                                dp = diff(p, xi, a, x, b, eta, c, y, d)
                                dq = diff(q, x, a, xi, b, y, c, eta, d)
                                result += coeff * dp * dq

            return simplify(result)

        else:
            raise NotImplementedError("Only 1D and 2D cases are implemented")
 833
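The Kohn–Nirenberg formula in the Notes can be checked independently of the class. The sketch below is a standalone re-implementation of the 1D case with the 'standard' convention D = −i∂ₓ; `kn_compose` is a hypothetical name. For symbols polynomial in ξ the series terminates, so the result is exact for differential operators, which makes the canonical relation [X, D] = i and the product rule for D² convenient checks.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_compose(p, q, order=2):
    """Truncated 1D Kohn-Nirenberg composition (standard convention D = -i d/dx):
    sigma(p o q) ~ sum_{n<=order} (1/n!) d_xi^n p * (-i d_x)^n q."""
    return sp.expand(sum(
        sp.diff(p, xi, n) * (-sp.I)**n * sp.diff(q, x, n) / sp.factorial(n)
        for n in range(order + 1)
    ))


# Canonical pair: X (symbol x) and D (symbol xi)
XD = kn_compose(x, xi)   # x*xi
DX = kn_compose(xi, x)   # x*xi - I
print(XD - DX)           # symbol of [X, D] = i

# D^2 composed with multiplication by sin(x): exact at order 2
print(kn_compose(xi**2, sp.sin(x)))
```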
 834    def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
 835        """
        Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators
        using formal asymptotic expansions of their composition symbols.

        The commutator symbol is obtained from two calls to `compose_asymptotic`, in the
        Kohn–Nirenberg or Weyl quantization. The result is a purely symbolic sympy
        expression that captures the noncommutativity of the operators up to the given order.

        Parameters
        ----------
        other : PseudoDifferentialOperator
            The pseudo-differential operator B to commute with this operator A.
        order : int, default=1
            Maximum order of the asymptotic expansion.
            - In Kohn–Nirenberg mode, order=1 gives −i{p, q} with the standard convention
              (+i{p, q} with 'inverse'), where {p, q} = ∂_ξp ∂_xq − ∂_xp ∂_ξq is the
              Poisson bracket (summed over both coordinate pairs in 2D).
            - Higher orders include correction terms involving higher mixed derivatives.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization used for both compositions.
        sign_convention : {'standard', 'inverse'}, optional
            Phase convention passed to `compose_asymptotic` (defaults to 'standard').

        Returns
        -------
        sympy.Expr
            Symbolic expression for the asymptotic expansion of the commutator symbol
            σ([A,B]) = σ(A∘B) − σ(B∘A).
        """
        assert self.dim == other.dim, "Operator dimensions must match"

        pq = self.compose_asymptotic(other, order=order, mode=mode, sign_convention=sign_convention)
        qp = other.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)

        comm_symbol = simplify(pq - qp)

        return comm_symbol
 869
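The order-1 statement above (σ([A, B]) ≈ −i{p, q} in the standard convention) can be checked against a Moyal product. The sketch below implements the 1D Weyl product with the standard convention D = −i∂ₓ; `weyl_compose` and `poisson` are hypothetical helper names. For symbols of degree at most two, the order-2 Moyal terms are symmetric and cancel in the commutator while all terms of order 3 or more vanish, so the match with −i{a, b} is exact.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def weyl_compose(a, b, order=2):
    """Truncated 1D Weyl (Moyal) product, standard convention D = -i d/dx:
    a # b = exp[(-i/2)(d_xi^a d_x^b - d_x^a d_xi^b)] a b, expanded binomially."""
    total = 0
    for n in range(order + 1):
        for k in range(n + 1):
            coeff = (-sp.I / 2)**n * (-1)**(n - k) / (sp.factorial(k) * sp.factorial(n - k))
            total += coeff * sp.diff(a, xi, k, x, n - k) * sp.diff(b, x, k, xi, n - k)
    return sp.expand(total)


def poisson(a, b):
    # {a, b} = d_xi a d_x b - d_x a d_xi b
    return sp.diff(a, xi) * sp.diff(b, x) - sp.diff(a, x) * sp.diff(b, xi)


a, b = xi**2, x**2
comm = sp.expand(weyl_compose(a, b) - weyl_compose(b, a))
print(comm)                                # commutator symbol of D^2 and X^2
print(sp.expand(-sp.I * poisson(a, b)))    # -i{a, b}: identical for quadratic symbols
```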
 870    def right_inverse_asymptotic(self, order=1):
 871        """
        Construct a formal right inverse (parametrix) R of the pseudo-differential operator P
        such that the composition P ∘ R equals the identity up to a remainder of order -order.

        This method computes an asymptotic expansion of R by a fixed-point iteration on the
        Kohn–Nirenberg composition formula, starting from r₀ = 1/p(x, ξ).

        Parameters
        ----------
        order : int
            Number of correction steps in the asymptotic expansion. Higher values improve
            the approximation at the cost of larger symbolic expressions.

        Returns
        -------
        sympy.Expr
            The symbolic expression representing the formal right inverse R(x, ξ), which satisfies:
            P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.

        Notes
        -----
        - Writing σ(P∘R) = p R + Σ_{|α|≥1} (1/α!) (-i)^{|α|} ∂_ξ^α p ∂_x^α R = 1 and solving
          for R gives the fixed point R = r₀ − r₀ Σ_{|α|≥1} (1/α!) (-i)^{|α|} ∂_ξ^α p ∂_x^α R.
          Step n evaluates the right-hand side on the previous iterate, with |α| ≤ n.
        - In 1D the sum runs over α = k; in 2D over multi-indices (k₁, k₂) for mixed
          derivatives in (ξ, η) and (x, y).
        - The construction requires p to be non-vanishing (ellipticity) so that r₀ = 1/p exists.
        - The standard convention D = -i∂ₓ is used, as in
          `compose_asymptotic(..., sign_convention='standard')`.
        """
        from sympy import I, diff, factorial, symbols

        p = self.symbol
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            r = 1 / p  # r0
            R = r
            for n in range(1, order + 1):
                term = 0
                for k in range(1, n + 1):
                    coeff = (-I) ** k / factorial(k)
                    term += coeff * diff(p, xi, k) * diff(R, x, k)
                # Fixed-point update: restart from r0 (not from R) to avoid double counting
                R = r - r * term
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            r = 1 / p  # r0
            R = r
            for n in range(1, order + 1):
                term = 0
                for k1 in range(n + 1):
                    for k2 in range(n + 1 - k1):
                        if k1 + k2 == 0:
                            continue
                        coeff = (-I) ** (k1 + k2) / (factorial(k1) * factorial(k2))
                        dp = diff(p, xi, k1, eta, k2)
                        dR = diff(R, x, k1, y, k2)
                        term += coeff * dp * dR
                R = r - r * term
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
        return R
 927
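The fixed-point iteration described in the Notes can be exercised on a concrete elliptic symbol. The sketch below is standalone SymPy (the helper names `kn_compose`, `right_parametrix` and `remainder` are hypothetical) and uses p = ξ² + x² + 1 as an assumed example. It measures |σ(P∘R) − 1| at one phase-space point for 0, 1 and 2 correction steps; the remainder should shrink with each step.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_compose(p, q, order):
    # Truncated 1D Kohn-Nirenberg composition, D = -i d/dx
    return sum(sp.diff(p, xi, n) * (-sp.I)**n * sp.diff(q, x, n) / sp.factorial(n)
               for n in range(order + 1))


def right_parametrix(p, order):
    # Fixed point R = r0 - r0 * sum_{1<=k<=n} (-i)^k/k! d_xi^k p d_x^k R, with r0 = 1/p
    r0 = 1 / p
    R = r0
    for n in range(1, order + 1):
        corr = sum((-sp.I)**k / sp.factorial(k) * sp.diff(p, xi, k) * sp.diff(R, x, k)
                   for k in range(1, n + 1))
        R = r0 - r0 * corr
    return R


p = xi**2 + x**2 + 1   # elliptic example symbol of order 2 (assumption)


def remainder(R, x0=1, xi0=100):
    # |sigma(P o R) - 1| at one point; order=2 is exact because p is quadratic in xi
    E = kn_compose(p, R, order=2) - 1
    return abs(complex(E.subs({x: x0, xi: xi0}).evalf()))


errs = [remainder(right_parametrix(p, N)) for N in range(3)]
print(errs)   # expected to decrease with the number of correction steps
```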
 928    def left_inverse_asymptotic(self, order=1):
 929        """
        Construct a formal left inverse (parametrix) L such that the composition L ∘ P equals
        the identity up to terms of order ⟨ξ⟩^{-order}. The expansion is performed
        asymptotically at infinity in the frequency variable(s).

        The left inverse is built iteratively from symbolic derivatives using the
        Kohn–Nirenberg composition formula, so that

            L(x, D) ∘ P(x, D) = Id + remainder of order -order

        Parameters
        ----------
        order : int, optional
            Number of correction steps in the asymptotic expansion (default is 1). Higher
            values yield more accurate inverses at the cost of increased computational complexity.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the (truncated) full symbol L(x, ξ) of the formal left
            inverse. It depends on spatial variables and frequencies and includes correction
            terms up to the specified order.

        Notes
        -----
        - Solving σ(L∘P) = L p + Σ_{|α|≥1} (1/α!) (-i)^{|α|} ∂_ξ^α L ∂_x^α p = 1 for L gives the
          fixed point L = l₀ − (Σ_{|α|≥1} (1/α!) (-i)^{|α|} ∂_ξ^α L ∂_x^α p) l₀, with l₀ = 1/p.
        - In 1D: α = k. In 2D: multi-indices for mixed derivatives in (x, y) and (ξ, η).
        - Coefficients are (-i)^{|α|}/α!, i.e. the standard convention D = -i∂ₓ.
        """
        from sympy import I, diff, factorial, symbols

        p = self.symbol
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            l = 1 / p  # l0
            L = l
            for n in range(1, order + 1):
                term = 0
                for k in range(1, n + 1):
                    coeff = (-I) ** k / factorial(k)
                    term += coeff * diff(L, xi, k) * diff(p, x, k)
                # Fixed-point update: restart from l0 (not from L) to avoid double counting
                L = l - term * l
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            l = 1 / p  # l0
            L = l
            for n in range(1, order + 1):
                term = 0
                for k1 in range(n + 1):
                    for k2 in range(n + 1 - k1):
                        if k1 + k2 == 0:
                            continue
                        coeff = (-I) ** (k1 + k2) / (factorial(k1) * factorial(k2))
                        dp = diff(p, x, k1, y, k2)
                        dL = diff(L, xi, k1, eta, k2)
                        term += coeff * dL * dp
                L = l - term * l
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
        return L
 989
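The left-sided iteration can be checked the same way as the right one. This standalone sketch (hypothetical helpers `kn_compose`, `left_parametrix`, `remainder`; example symbol p = ξ² + x² + 1 assumed) measures |σ(L∘P) − 1| at one point; since ∂ₓ³p = 0, composing at order 2 is exact, so the printed numbers isolate the parametrix error.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_compose(p, q, order):
    # Truncated 1D Kohn-Nirenberg composition, D = -i d/dx
    return sum(sp.diff(p, xi, n) * (-sp.I)**n * sp.diff(q, x, n) / sp.factorial(n)
               for n in range(order + 1))


def left_parametrix(p, order):
    # Fixed point L = l0 - (sum_{1<=k<=n} (-i)^k/k! d_xi^k L d_x^k p) * l0, with l0 = 1/p
    l0 = 1 / p
    L = l0
    for n in range(1, order + 1):
        corr = sum((-sp.I)**k / sp.factorial(k) * sp.diff(L, xi, k) * sp.diff(p, x, k)
                   for k in range(1, n + 1))
        L = l0 - corr * l0
    return L


p = xi**2 + x**2 + 1   # example symbol (assumption)


def remainder(L, x0=1, xi0=100):
    # |sigma(L o P) - 1| at one phase-space point
    E = kn_compose(L, p, order=2) - 1
    return abs(complex(E.subs({x: x0, xi: xi0}).evalf()))


errs = [remainder(left_parametrix(p, N)) for N in range(3)]
print(errs)   # expected to decrease with the number of correction steps
```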
    def formal_adjoint(self, order=3):
        """
        Compute the formal adjoint symbol P* of the pseudo-differential operator.

        The adjoint is defined such that for any test functions u and v,
        ⟨P u, v⟩ = ⟨u, P* v⟩. In the Kohn–Nirenberg quantization its symbol is
        given by the asymptotic expansion

            p*(x, ξ) ~ Σ_α (1/α!) ∂_ξ^α (-i ∂_x)^α conj(p(x, ξ)),

        truncated at |α| ≤ order. When p does not depend on x, only the leading
        term conj(p) survives.

        Parameters
        ----------
        order : int, default=3
            Maximum total derivative order |α| kept in the expansion. The series
            terminates (and the result is exact) for symbols polynomial in ξ of
            degree ≤ order.

        Returns
        -------
        sympy.Expr
            The adjoint symbol P*(x, ξ) in 1D or P*(x, y, ξ, η) in 2D.

        Notes
        -----
        - The standard convention D = -i∂ₓ is used, consistent with
          `compose_asymptotic(..., sign_convention='standard')`.
        - Example: for p = x ξ (the operator X D) the formula gives x ξ − i, the
          symbol of D X = (X D)*.
        """
        from sympy import I, conjugate, diff, factorial, simplify, symbols

        p_bar = conjugate(self.symbol)
        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)
            p_star = sum(
                (-I) ** k / factorial(k) * diff(p_bar, xi, k, x, k)
                for k in range(order + 1)
            )
            return simplify(p_star)
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            p_star = 0
            for k1 in range(order + 1):
                for k2 in range(order + 1 - k1):
                    p_star += ((-I) ** (k1 + k2) / (factorial(k1) * factorial(k2))
                               * diff(p_bar, xi, k1, eta, k2, x, k1, y, k2))
            return simplify(p_star)
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
1022
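The Kohn–Nirenberg adjoint formula can be tested on first-order operators where the answer is known by hand: (X D)* = D X and (sin(x) D)* = D sin(x) = sin(x) D − i cos(x). The sketch below is a standalone version of the 1D formula (`kn_adjoint` is a hypothetical name) under the standard convention D = −i∂ₓ, with x and ξ declared real so that conjugation only acts on explicit complex coefficients.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_adjoint(p, order=3):
    """Kohn-Nirenberg adjoint symbol, standard convention D = -i d/dx:
    p*(x, xi) ~ sum_k (1/k!) d_xi^k (-i d_x)^k conj(p)."""
    pb = sp.conjugate(p)
    return sp.expand(sum((-sp.I)**k / sp.factorial(k) * sp.diff(pb, xi, k, x, k)
                         for k in range(order + 1)))


print(kn_adjoint(x * xi))              # symbol of (X D)* = D X
print(kn_adjoint(sp.sin(x) * xi))      # symbol of D sin(x) = sin(x) D - i cos(x)
print(kn_adjoint(kn_adjoint(x * xi)))  # taking the adjoint twice recovers x*xi
```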
1023    def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
1024        """
        Compute the symbol of exp(tP) using asymptotic expansion methods.

        This method approximates the exponential of a pseudo-differential operator by
        its truncated Taylor series, with each power P^n computed through asymptotic
        composition. The result is valid up to the specified asymptotic order.

        Parameters
        ----------
        t : float or sympy.Symbol, default=1.0
            Time or evolution parameter. Common uses:
            - t = -i*τ for Schrödinger evolution: exp(-iτH)
            - t = τ for heat/diffusion: exp(τΔ)
            - t for general propagators
        order : int, default=1
            Highest power of P kept in the Taylor series; also used as the order of
            each composition. Higher orders improve accuracy for small t and capture
            more non-commutativity effects.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization used for the compositions (see `compose_asymptotic`).
        sign_convention : {'standard', 'inverse'}, optional
            Phase convention passed to `compose_asymptotic`.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the exponential operator symbol, computed
            as a truncated series up to the specified order.

        Notes
        -----
        - For symbols independent of x (Fourier multipliers) or independent of ξ
          (multiplication operators), every composition reduces to a product, so the
          result is the Taylor polynomial of exp(t*p(x,ξ)); it is not summed in closed form.

        - For general non-commutative operators, the method uses the truncated series
          exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...

        - Each power P^n = P^(n-1) ∘ P is computed via compose_asymptotic, which accounts
          for the non-commutativity through derivative terms.

        - The truncated series is accurate only where |t p(x,ξ)| is small; in particular
          it is not uniform as |ξ| → ∞.

        - In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) represents
          the time evolution operator.

        - In parabolic PDEs (heat equation): exp(tΔ) is the heat semigroup, whose
          Schwartz kernel is the heat kernel.
1069
1070        """
1071        if self.dim == 1:
1072            x = self.vars_x[0]
1073            xi = symbols('xi', real=True)
1074            
1075            # Initialize with identity
1076            result = 1
1077            
1078            # First order term: tP
1079            current_power = self.symbol
1080            result += t * current_power
1081            
1082            # Higher order terms: (t^n/n!) P^n computed via composition
1083            for n in range(2, order + 1):
1084                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1085                # We use a temporary operator for composition
1086                temp_op = PseudoDifferentialOperator(
1087                    current_power, [x], mode='symbol'
1088                )
1089                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1090                
1091                # Add term (t^n/n!) * P^n
1092                coeff = t**n / factorial(n)
1093                result += coeff * current_power
1094            
1095            return simplify(result)
1096        
1097        elif self.dim == 2:
1098            x, y = self.vars_x
1099            xi, eta = symbols('xi eta', real=True)
1100            
1101            # Initialize with identity
1102            result = 1
1103            
1104            # First order term: tP
1105            current_power = self.symbol
1106            result += t * current_power
1107            
1108            # Higher order terms: (t^n/n!) P^n computed via composition
1109            for n in range(2, order + 1):
1110                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1111                temp_op = PseudoDifferentialOperator(
1112                    current_power, [x, y], mode='symbol'
1113                )
1114                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1115                
1116                # Add term (t^n/n!) * P^n
1117                coeff = t**n / factorial(n)
1118                result += coeff * current_power
1119            
1120            return simplify(result)
1121        
1122        else:
1123            raise NotImplementedError("Only 1D and 2D operators are supported")
1124        
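The truncated Taylor series used by `exponential_symbol` can be reproduced with a standalone Kohn–Nirenberg composition (hypothetical helpers `kn_compose` and `exp_symbol`, standard convention D = −i∂ₓ). For an x-independent symbol such as −ξ² (heat operator) the compositions are plain products, so the result matches the Taylor polynomial of exp(−tξ²) in t; for p = x ξ the square P∘P differs from p² by −i x ξ.

```python
import sympy as sp

x, xi, t = sp.symbols('x xi t', real=True)


def kn_compose(p, q, order):
    # Truncated 1D Kohn-Nirenberg composition, D = -i d/dx
    return sp.expand(sum(sp.diff(p, xi, n) * (-sp.I)**n * sp.diff(q, x, n) / sp.factorial(n)
                         for n in range(order + 1)))


def exp_symbol(p, t, order):
    # Truncated Taylor series sum_{n<=order} t^n/n! * sigma(P^n), with P^n = P^(n-1) o P
    result, power = sp.Integer(1), sp.Integer(1)
    for n in range(1, order + 1):
        power = kn_compose(power, p, order)
        result += t**n / sp.factorial(n) * power
    return sp.expand(result)


# x-independent symbol (heat operator): compositions are plain products
print(exp_symbol(-xi**2, t, 3))
# x-dependent symbol: the t**2 term carries the extra -i*x*xi/2 correction
print(exp_symbol(x * xi, t, 2))
```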
1125    def trace_formula(self, volume_element=None, numerical=False, 
1126                      x_bounds=None, xi_bounds=None):
1127        """
1128        Compute the semiclassical trace of the pseudo-differential operator.
1129        
1130        The trace formula relates the quantum trace of an operator to a 
1131        phase-space integral of its symbol, providing a fundamental link 
1132        between classical and quantum mechanics. This implementation supports 
1133        both symbolic and numerical integration.
1134        
1135        Parameters
1136        ----------
1137        volume_element : sympy.Expr, optional
            Custom density multiplying the symbol in the phase-space integral.
            If None, the constant (2π)^{-d} is used, i.e. the Liouville measure
            dx dξ normalized by (2π)^{-d}.
1140        numerical : bool, default=False
1141            If True, perform numerical integration over specified bounds.
1142            If False, attempt symbolic integration (may fail for complex symbols).
1143        x_bounds : tuple of tuples, optional
1144            Spatial integration bounds. For 1D: ((x_min, x_max),)
1145            For 2D: ((x_min, x_max), (y_min, y_max))
1146            Required if numerical=True.
1147        xi_bounds : tuple of tuples, optional
1148            Frequency integration bounds. For 1D: ((xi_min, xi_max),)
1149            For 2D: ((xi_min, xi_max), (eta_min, eta_max))
1150            Required if numerical=True.
1151        
1152        Returns
1153        -------
1154        sympy.Expr or float
1155            The trace of the operator. Returns a symbolic expression if 
1156            numerical=False, or a float if numerical=True.
1157        
1158        Notes
1159        -----
1160        - The semiclassical trace formula states:
1161          Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ
1162          where d is the spatial dimension and p(x,ξ) is the operator symbol.
1163        
1164        - For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ
1165        
1166        - For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη
1167        
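        - The identity is exact for trace-class operators with integrable symbols, in
          both Kohn–Nirenberg and Weyl quantization. For symbols that are not integrable
          (e.g. elliptic symbols of order ≥ 0) the integral diverges, and only a truncated
          integral over finite bounds is meaningful.

        - Physical interpretation: the trace counts the "number of states"
          weighted by the observable p(x,ξ).

        - For a projection, the trace equals the dimension of its range. If P approximately
          projects onto a phase-space region Ω (symbol ≈ χ_Ω), the formula gives the
          semiclassical count vol(Ω)/(2π)^d of states in Ω.

        - The factor (2π)^{-d} comes from the Fourier-transform normalization of the
          quantization (ℏ = 1); with ℏ restored it becomes (2πℏ)^{-d}, i.e. one state per
          phase-space cell of volume (2πℏ)^d.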
1179        """
        from sympy import integrate, simplify, lambdify, Integral
        from scipy.integrate import dblquad, nquad

        p = self.symbol

        if numerical:
            if x_bounds is None or xi_bounds is None:
                raise ValueError(
                    "x_bounds and xi_bounds must be provided for numerical integration"
                )

        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)

            if volume_element is None:
                volume_element = 1 / (2 * pi)

            if numerical:
                # Numerical integration; the volume element is folded into the integrand
                # so that both the value and the error estimate carry the normalization.
                p_func = lambdify((x, xi), p * volume_element, 'numpy')
                (x_min, x_max), = x_bounds
                (xi_min, xi_max), = xi_bounds

                # dblquad integrates func(y, x) with y (here xi) as the inner variable
                def integrand(xi_val, x_val):
                    return p_func(x_val, xi_val)

                result, error = dblquad(
                    integrand,
                    x_min, x_max,
                    lambda _: xi_min, lambda _: xi_max
                )

                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
                return result

            else:
                # Symbolic integration
                integrand = p * volume_element

                try:
                    # Try to integrate over xi first, then x
                    integral_xi = integrate(integrand, (xi, -oo, oo))
                    integral_x = integrate(integral_xi, (x, -oo, oo))
                    return simplify(integral_x)
                except Exception:
                    print("Warning: Symbolic integration failed. Try numerical=True")
                    # Return the unevaluated integral instead of repeating the failed computation
                    return Integral(integrand, (xi, -oo, oo), (x, -oo, oo))

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)

            if volume_element is None:
                volume_element = 1 / (4 * pi**2)

            if numerical:
                # Numerical integration in 4D; the volume element is folded into the integrand
                p_func = lambdify((x, y, xi, eta), p * volume_element, 'numpy')
                (x_min, x_max), (y_min, y_max) = x_bounds
                (xi_min, xi_max), (eta_min, eta_max) = xi_bounds

                # nquad passes arguments innermost-first: ranges[0] belongs to eta
                def integrand(eta_val, xi_val, y_val, x_val):
                    return p_func(x_val, y_val, xi_val, eta_val)

                result, error = nquad(
                    integrand,
                    [
                        [eta_min, eta_max],
                        [xi_min, xi_max],
                        [y_min, y_max],
                        [x_min, x_max]
                    ]
                )

                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
                return result

            else:
                # Symbolic integration
                integrand = p * volume_element

                try:
                    # Integrate in order: eta, xi, y, x
                    integral_eta = integrate(integrand, (eta, -oo, oo))
                    integral_xi = integrate(integral_eta, (xi, -oo, oo))
                    integral_y = integrate(integral_xi, (y, -oo, oo))
                    integral_x = integrate(integral_y, (x, -oo, oo))
                    return simplify(integral_x)
                except Exception:
                    print("Warning: Symbolic integration failed. Try numerical=True")
                    # Return the unevaluated integral instead of repeating the failed computation
                    return Integral(
                        integrand,
                        (eta, -oo, oo), (xi, -oo, oo),
                        (y, -oo, oo), (x, -oo, oo)
                    )

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
1281
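The numerical branch of the trace reduces to a double integral of the symbol over a truncated phase-space box, scaled by 1/(2π) in 1D. The sketch below reproduces that computation outside the class for a Gaussian symbol, assuming the symbol decays fast enough that the box [-8, 8] × [-8, 8] captures essentially all of its mass. `numerical_trace_1d` is a hypothetical helper for illustration, not part of psipy.

```python
import numpy as np
from scipy.integrate import dblquad
from sympy import exp, lambdify, symbols


def numerical_trace_1d(p_expr, x, xi, x_bounds, xi_bounds):
    """Approximate (1/2π) ∫∫ p(x, ξ) dx dξ over a finite box (hypothetical helper)."""
    p_func = lambdify((x, xi), p_expr, "numpy")
    (x_min, x_max), (xi_min, xi_max) = x_bounds, xi_bounds
    # dblquad integrates func(inner, outer): ξ is the inner variable, x the outer one
    value, err = dblquad(lambda xi_val, x_val: p_func(x_val, xi_val),
                         x_min, x_max,
                         lambda _: xi_min, lambda _: xi_max)
    return value / (2 * np.pi), err / (2 * np.pi)


x, xi = symbols("x xi", real=True)
trace, err = numerical_trace_1d(exp(-x**2 - xi**2), x, xi, (-8, 8), (-8, 8))
print(trace)  # close to 0.5: the Gaussian integrates to π over the plane
```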
    def symplectic_flow(self):
        """
        Compute the Hamiltonian vector field generated by the operator's symbol.

        This method derives the canonical equations of motion for the phase-space variables
        (x, ξ) in 1D or (x, y, ξ, η) in 2D, taking the symbol as the Hamiltonian. These
        equations describe how position and frequency variables evolve under the flow
        generated by the symbol.

        Returns
        -------
        dict
            A dictionary containing the components of the Hamiltonian vector field:
            - In 1D: keys 'dx/dt' and 'dxi/dt', with dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x.
            - In 2D: keys 'dx/dt', 'dy/dt', 'dxi/dt' and 'deta/dt', with
              dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1 or 2.

        Notes
        -----
        - The Hamiltonian is the full symbol stored in `self.symbol`; if it contains
          lower-order terms, they contribute to the flow as well.
        - This flow preserves the symplectic structure of phase space.
        """
        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)
            return {
                'dx/dt': diff(self.symbol, xi),
                'dxi/dt': -diff(self.symbol, x)
            }
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            return {
                'dx/dt': diff(self.symbol, xi),
                'dy/dt': diff(self.symbol, eta),
                'dxi/dt': -diff(self.symbol, x),
                'deta/dt': -diff(self.symbol, y)
            }
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

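As a concrete check of the sign conventions used by `symplectic_flow`, the sketch below applies the same formulas (dx/dt = ∂p/∂ξ, dξ/dt = -∂p/∂x) to the harmonic-oscillator symbol p = x² + ξ² with plain SymPy, outside the class. `hamiltonian_field_1d` is a hypothetical helper that mirrors the 1D branch. It also confirms that p is conserved along its own flow.

```python
from sympy import diff, simplify, symbols


def hamiltonian_field_1d(p, x, xi):
    """Return (dx/dt, dxi/dt) for the Hamiltonian p(x, xi) (hypothetical helper)."""
    return diff(p, xi), -diff(p, x)


x, xi = symbols("x xi", real=True)
p = x**2 + xi**2
dx_dt, dxi_dt = hamiltonian_field_1d(p, x, xi)
print(dx_dt, dxi_dt)  # 2*xi -2*x

# dp/dt along the flow = ∂p/∂x · dx/dt + ∂p/∂ξ · dξ/dt, which vanishes identically
dp_dt = simplify(diff(p, x) * dx_dt + diff(p, xi) * dxi_dt)
print(dp_dt)  # 0
```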
    def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-8):
        """
        Check numerically whether the symbol p(x, ξ) stays away from zero on a given grid.

        The symbol is evaluated on a tensor grid of spatial and frequency coordinates, and the
        method checks whether its minimum absolute value exceeds `threshold`. This is a
        heuristic proxy for ellipticity: the strict definition for a symbol of order m asks for
        |p(x, ξ)| ≥ c|ξ|^m whenever |ξ| is large, which a finite grid cannot verify.

        Grids longer than 32 points per axis are resampled with `np.linspace` to limit memory
        usage, particularly in 2D, where the evaluation grid is four-dimensional. Resampling
        can skip exact zeros of the symbol (for example ξ = 0 on a symmetric grid with an even
        number of points), so a True result is only as reliable as the sampling.

        Parameters
        ----------
        x_grid : ndarray or tuple of ndarray
            Spatial grid: a 1D array (x) in 1D, or a tuple of two 1D arrays (x, y) in 2D.
        xi_grid : ndarray or tuple of ndarray
            Frequency grid: a 1D array (ξ) in 1D, or a tuple of two 1D arrays (ξ, η) in 2D.
        threshold : float, optional
            Minimum acceptable value of |p(x, ξ)|. If the smallest evaluated value falls
            below it, the symbol is not considered elliptic. Default is 1e-8.

        Returns
        -------
        bool
            True if min |p| > threshold on the (possibly resampled) grid, False otherwise.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1 or 2.
        """
        RESAMPLE_SIZE = 32  # Points per axis; caps the 2D evaluation at 32**4 points

        if self.dim == 1:
            x_vals = np.asarray(x_grid)
            xi_vals = np.asarray(xi_grid)
            # Resample if necessary
            if len(x_vals) > RESAMPLE_SIZE:
                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
            if len(xi_vals) > RESAMPLE_SIZE:
                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)

            X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
            symbol_vals = self.p_func(X, XI)

        elif self.dim == 2:
            x_vals, y_vals = map(np.asarray, x_grid)
            xi_vals, eta_vals = map(np.asarray, xi_grid)

            # Spatial resampling
            if len(x_vals) > RESAMPLE_SIZE:
                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
            if len(y_vals) > RESAMPLE_SIZE:
                y_vals = np.linspace(y_vals.min(), y_vals.max(), RESAMPLE_SIZE)

            # Frequency resampling
            if len(xi_vals) > RESAMPLE_SIZE:
                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
            if len(eta_vals) > RESAMPLE_SIZE:
                eta_vals = np.linspace(eta_vals.min(), eta_vals.max(), RESAMPLE_SIZE)

            X, Y, XI, ETA = np.meshgrid(x_vals, y_vals, xi_vals, eta_vals, indexing='ij')
            symbol_vals = self.p_func(X, Y, XI, ETA)

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

        min_abs_val = np.min(np.abs(symbol_vals))
        return bool(min_abs_val > threshold)

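The grid test in `is_elliptic_numerically` can be reproduced with NumPy alone. In the sketch below, `min_abs_on_grid` is a hypothetical helper: on a grid that contains ξ = 0, the symbol 1 + ξ² passes the threshold while ξ² does not. The last line shows why the resampling step deserves care: a 32-point `linspace` over [-5, 5] has no node at ξ = 0, so the zero of ξ² is missed.

```python
import numpy as np


def min_abs_on_grid(p_func, x_vals, xi_vals):
    """Minimum of |p(x, xi)| over the tensor grid (hypothetical helper)."""
    X, XI = np.meshgrid(x_vals, xi_vals, indexing="ij")
    return np.min(np.abs(p_func(X, XI)))


x_vals = np.linspace(-1.0, 1.0, 11)
xi_fine = np.linspace(-5.0, 5.0, 21)       # odd count: contains xi = 0 exactly
xi_resampled = np.linspace(-5.0, 5.0, 32)  # even count: xi = 0 is not a node

threshold = 1e-8
elliptic = min_abs_on_grid(lambda x, xi: 1 + xi**2, x_vals, xi_fine) > threshold
degenerate = min_abs_on_grid(lambda x, xi: xi**2, x_vals, xi_fine) > threshold
missed = min_abs_on_grid(lambda x, xi: xi**2, x_vals, xi_resampled) > threshold
print(elliptic, degenerate, missed)  # True False True
```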
    def is_self_adjoint(self, tol=1e-10):
        """
        Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).

        A self-adjoint operator satisfies P = P*, where P* is the formal adjoint of P.
        This property underlies real-valued eigenvalues and stable (norm-preserving)
        evolution in quantum mechanics and symmetric wave propagation.

        Parameters
        ----------
        tol : float
            Currently unused: equality is decided by symbolic simplification, not by a
            numerical tolerance. Kept for API compatibility.

        Returns
        -------
        bool
            True if simplify(p - p*) is recognized as zero, indicating that the operator is
            self-adjoint. If SymPy cannot decide the equality, False is returned.

        Notes
        -----
        - The formal adjoint is obtained from `formal_adjoint()`, which conjugates the symbol
          and applies the asymptotic expansion in ξ.
        - Symbolic simplification is used to verify equality, which makes the check robust
          against superficial differences between equivalent expressions.
        """
        p = self.symbol
        p_star = self.formal_adjoint()
        # Expr.equals returns True, False or None (undecidable); treat None as "not proven"
        return simplify(p - p_star).equals(0) is True

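For intuition about what `is_self_adjoint` compares, here is a sketch of the standard adjoint expansion for a 1D symbol under the Kohn–Nirenberg quantization, p*(x, ξ) ~ Σₖ (1/k!) ∂_ξᵏ (-i∂ₓ)ᵏ conj(p). The quantization convention and the truncation order are assumptions; the exact expansion used by `formal_adjoint()` is not shown in this section, and `adjoint_symbol_1d` is a hypothetical helper. For p = xξ (the operator x·D with D = -i d/dx) the expansion terminates and gives xξ - i, so the operator is not self-adjoint, whereas x² + ξ² is.

```python
from sympy import I, conjugate, diff, factorial, simplify, symbols


def adjoint_symbol_1d(p, x, xi, order=4):
    """Truncated Kohn–Nirenberg adjoint expansion (hypothetical helper, not psipy's API)."""
    term = conjugate(p)  # k = 0 term: the complex conjugate of the symbol
    total = 0
    for k in range(order + 1):
        total += (-I) ** k / factorial(k) * term
        term = diff(term, xi, x)  # the next term carries one more ∂_ξ ∂_x derivative
    return simplify(total)


x, xi = symbols("x xi", real=True)
print(adjoint_symbol_1d(x * xi, x, xi))        # the symbol x*xi - I: x·D is not self-adjoint
print(adjoint_symbol_1d(x**2 + xi**2, x, xi))  # unchanged: the harmonic oscillator is self-adjoint
```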
    def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
        """
        Plot the cotangent fiber structure of the symbol.

        In 2D, this shows how the symbol behaves on the cotangent fiber above the fixed
        spatial point (x₀, y₀); in microlocal analysis, this gives insight into the frequency
        content of the operator at that location. In 1D, the fiber over a single point is a
        line, so the method instead shows |p(x, ξ)| over the whole (x, ξ) grid.

        Parameters
        ----------
        x_grid : ndarray
            Spatial grid values (1D array), used only in the 1D case.
        xi_grid : ndarray
            Frequency grid values (1D array). In 2D, it is used for both ξ and η.
        x0 : float, optional
            x-coordinate of the base point (2D only). Default is 0.0.
        y0 : float, optional
            y-coordinate of the base point (2D only). Default is 0.0.

        Notes
        -----
        - In 1D: displays |p(x, ξ)| over the (x, ξ) phase plane; x0 is not used.
        - In 2D: evaluates p(x₀, y₀, ξ, η) on the square grid xi_grid × xi_grid.
        - The color map represents the magnitude of the symbol, highlighting regions where
          it vanishes or becomes large.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1 or 2.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.contourf(X, XI, np.abs(symbol_vals), levels=50, cmap='viridis')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x (position)')
            plt.ylabel('ξ (frequency)')
            plt.title('Cotangent Fiber Structure')
            plt.show()
        elif self.dim == 2:
            # Default 'xy' indexing: rows follow η and columns follow ξ, as contourf expects
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, xi_grid)
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            plt.contourf(xi_grid, xi_grid, np.abs(symbol_vals), levels=50, cmap='viridis')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Cotangent Fiber at x={x0}, y={y0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

    def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
        """
        Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.

        This method visualizes the amplitude of the pseudo-differential operator's symbol
        in either a 1D or a 2D spatial configuration. In 2D, the frequency variables are
        fixed to the specified values (ξ₀, η₀) for visualization purposes.

        Parameters
        ----------
        x_grid, y_grid : ndarray
            1D spatial grids. y_grid is required in 2D and ignored in 1D.
        xi_grid, eta_grid : ndarray
            1D frequency grids. In 1D, xi_grid defines the ξ-axis of the plot. In 2D, both
            are unused because the frequencies are fixed to (ξ₀, η₀); they are kept for API
            consistency.
        xi0, eta0 : float, optional
            Fixed frequency values for the 2D slice. Default is 0.0.

        Notes
        -----
        - In 1D: visualizes |p(x, ξ)| over the (x, ξ) grid.
        - In 2D: visualizes |p(x, y, ξ₀, η₀)| at fixed frequencies ξ₀ and η₀.
        - The color intensity represents the magnitude of the symbol, highlighting regions
          where the symbol is large or small.

        Raises
        ------
        ValueError
            If y_grid is missing in 2D.
        NotImplementedError
            If the spatial dimension is not 1 or 2.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.pcolormesh(X, XI, np.abs(symbol_vals), shading='auto')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Symbol Amplitude |p(x, ξ)|')
            plt.show()
        elif self.dim == 2:
            if y_grid is None:
                raise ValueError("y_grid must be provided for 2D visualization.")
            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
            XI = np.full_like(X, xi0)
            ETA = np.full_like(Y, eta0)
            symbol_vals = self.p_func(X, Y, XI, ETA)
            plt.pcolormesh(X, Y, np.abs(symbol_vals), shading='auto')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.title(f'Symbol Amplitude at ξ={xi0}, η={eta0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

    def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
        """
        Plot the phase (argument) of the symbol p(x, ξ) or p(x, y, ξ₀, η₀).

        This visualization helps in understanding the oscillatory behavior and regularity
        properties of the operator in phase space. The phase, computed with `np.angle` in
        the range [-π, π], is displayed with the cyclic colormap 'twilight' to emphasize its
        periodic nature.

        Parameters
        ----------
        x_grid : ndarray
            1D array of spatial coordinates (x).
        xi_grid : ndarray
            1D array of frequency coordinates (ξ), used in 1D.
        y_grid : ndarray, optional
            1D array of spatial coordinates (y). Required in 2D. Default is None.
        eta_grid : ndarray, optional
            Unused; kept for API consistency. Default is None.
        xi0 : float, optional
            Fixed value of ξ for the 2D slice. Default is 0.0.
        eta0 : float, optional
            Fixed value of η for the 2D slice. Default is 0.0.

        Notes
        -----
        - In 1D: displays arg(p(x, ξ)) over the (x, ξ) phase plane.
        - In 2D: displays arg(p(x, y, ξ₀, η₀)) for fixed frequency values (ξ₀, η₀).
        - Uses plt.pcolormesh with the 'twilight' colormap to represent angles from -π to π.

        Raises
        ------
        ValueError
            If y_grid is missing in 2D.
        NotImplementedError
            If the spatial dimension is not 1 or 2.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.pcolormesh(X, XI, np.angle(symbol_vals), shading='auto', cmap='twilight')
            plt.colorbar(label='arg(Symbol) [rad]')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Phase Portrait (arg p(x, ξ))')
            plt.show()
        elif self.dim == 2:
            if y_grid is None:
                raise ValueError("y_grid must be provided for 2D visualization.")
            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
            XI = np.full_like(X, xi0)
            ETA = np.full_like(Y, eta0)
            symbol_vals = self.p_func(X, Y, XI, ETA)
            plt.pcolormesh(X, Y, np.angle(symbol_vals), shading='auto', cmap='twilight')
            plt.colorbar(label='arg(Symbol) [rad]')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.title(f'Phase Portrait at ξ={xi0}, η={eta0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

    def visualize_characteristic_set(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=(1e-1,)):
        """
        Visualize the characteristic set of the symbol, approximated by level sets |p| = ε.

        In microlocal analysis, the characteristic set is the locus of points (x, ξ) in phase
        space where the principal symbol vanishes; it plays a key role in the propagation of
        singularities. Since a grid rarely hits exact zeros, the method draws contours of |p|
        at small positive levels.

        Parameters
        ----------
        x_grid : ndarray
            Spatial grid values (1D array), used in the 1D case.
        xi_grid : ndarray
            Frequency grid values (1D array) for ξ.
        y_grid : ndarray, optional
            Unused; kept for API consistency.
        eta_grid : ndarray, optional
            Frequency grid values (1D array) for η. Required in 2D.
        y0 : float, optional
            Fixed y-coordinate of the base point in 2D. Default is 0.0.
        x0 : float, optional
            Fixed x-coordinate of the base point in 2D. Default is 0.0.
        levels : sequence of float, optional
            Levels ε at which the contours |p| = ε are drawn. Default is (0.1,).

        Notes
        -----
        - In 1D, the contours |p(x, ξ)| = ε are plotted over the (x, ξ) plane.
        - In 2D, the symbol is evaluated at fixed (x₀, y₀) and the contours are plotted in
          the (ξ, η) frequency plane.
        - This visualization shows the directions in which the operator fails to be elliptic.

        Raises
        ------
        ValueError
            If eta_grid is missing in 2D.
        NotImplementedError
            If the spatial dimension is not 1 or 2.

        Displays
        --------
        A matplotlib contour plot showing either:
            - The characteristic curve in the (x, ξ) phase plane (1D),
            - The slice of the characteristic set in the (ξ, η) frequency plane at (x₀, y₀) (2D).
        """
        if self.dim == 1:
            x_grid = np.asarray(x_grid)
            xi_grid = np.asarray(xi_grid)
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.contour(X, XI, np.abs(symbol_vals), levels=levels, colors='red')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Characteristic Set (p(x, ξ) ≈ 0)')
            plt.grid(True)
            plt.show()
        elif self.dim == 2:
            if eta_grid is None:
                raise ValueError("eta_grid must be provided for 2D visualization.")
            xi_grid = np.asarray(xi_grid)
            eta_grid = np.asarray(eta_grid)
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            # Pass the 2D coordinate arrays: with indexing='ij', symbol_vals has shape
            # (len(ξ), len(η)), which contour(xi_grid, eta_grid, Z) would read transposed
            plt.contour(xi_grid2, eta_grid2, np.abs(symbol_vals), levels=levels, colors='red')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Characteristic Set at x={x0}, y={y0}')
            plt.grid(True)
            plt.show()
        else:
            raise NotImplementedError("Only 1D/2D characteristic sets supported.")

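The 2D fiber plots depend on how `np.meshgrid` lays out the array. With `indexing='ij'` the array has shape (len(ξ), len(η)), whereas Matplotlib's `contour(x, y, Z)` with 1D axes expects Z of shape (len(y), len(x)); passing the 2D coordinate arrays produced by `meshgrid` avoids the mismatch regardless of indexing. The NumPy-only sketch below checks both layouts for the sample symbol ξ² + η² - 1, whose characteristic set on every fiber is the unit circle.

```python
import numpy as np

xi = np.linspace(-2.0, 2.0, 81)
eta = np.linspace(-1.5, 1.5, 61)  # different length on purpose, so a transposed array fails loudly

XI_ij, ETA_ij = np.meshgrid(xi, eta, indexing="ij")
XI_xy, ETA_xy = np.meshgrid(xi, eta)  # default "xy" indexing
p_ij = XI_ij**2 + ETA_ij**2 - 1.0     # shape (len(xi), len(eta))
p_xy = XI_xy**2 + ETA_xy**2 - 1.0     # shape (len(eta), len(xi)): the layout contour(xi, eta, Z) expects
print(p_ij.shape, p_xy.shape)          # (81, 61) (61, 81)

# Grid points where |p| < eps should cluster around the unit circle
eps = 0.05
near = np.abs(p_ij) < eps
radii = np.hypot(XI_ij[near], ETA_ij[near])
print(radii.min(), radii.max())
```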
    def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
        """
        Visualize the norm of the gradient of the symbol in phase space.

        This method computes the magnitude of the gradient |∇p| of a pseudo-differential
        symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting color map reveals
        regions where the symbol varies rapidly or remains nearly stationary,
        which is particularly useful for analyzing characteristic sets.

        Parameters
        ----------
        x_grid : numpy.ndarray
            1D array of spatial coordinates for the x-direction.
        xi_grid : numpy.ndarray
            1D array of frequency coordinates (ξ).
        y_grid : numpy.ndarray, optional
            Unused; kept for API consistency. Default is None.
        eta_grid : numpy.ndarray, optional
            1D array of frequency coordinates (η). Required in 2D. Default is None.
        y0 : float, optional
            Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.
        x0 : float, optional
            Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0.

        Returns
        -------
        None
            Displays a 2D color map of |∇p| over the relevant phase-space domain.

        Raises
        ------
        ValueError
            If eta_grid is missing in 2D.
        NotImplementedError
            If the spatial dimension is not 1 or 2.

        Notes
        -----
        - In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
        - In 2D, the gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀)
          over the (ξ, η) grid.
        - Numerical differentiation uses `np.gradient` with the grid coordinates, so the
          derivatives are in physical units rather than per grid index.
        - Large |∇p| marks rapid variation of the symbol. On the characteristic set p = 0,
          points where |∇p| is small flag degenerate (critical) points, where the
          characteristic set may fail to be a smooth hypersurface.
        """
        if self.dim == 1:
            x_grid = np.asarray(x_grid)
            xi_grid = np.asarray(xi_grid)
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            grad_x, grad_xi = np.gradient(symbol_vals, x_grid, xi_grid)
            # np.abs keeps the norm real for complex-valued symbols
            grad_norm = np.sqrt(np.abs(grad_x)**2 + np.abs(grad_xi)**2)
            plt.pcolormesh(X, XI, grad_norm, cmap='inferno', shading='auto')
            plt.colorbar(label='|∇p|')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Gradient Norm |∇p|')
            plt.grid(True)
            plt.show()
        elif self.dim == 2:
            if eta_grid is None:
                raise ValueError("eta_grid must be provided for 2D visualization.")
            xi_grid = np.asarray(xi_grid)
            eta_grid = np.asarray(eta_grid)
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            grad_xi, grad_eta = np.gradient(symbol_vals, xi_grid, eta_grid)
            grad_norm = np.sqrt(np.abs(grad_xi)**2 + np.abs(grad_eta)**2)
            # Use the 2D coordinate arrays: grad_norm has shape (len(ξ), len(η))
            plt.pcolormesh(xi_grid2, eta_grid2, grad_norm, cmap='inferno', shading='auto')
            plt.colorbar(label='|∇p|')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Gradient Norm at x={x0}, y={y0}')
            plt.grid(True)
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

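`np.gradient` returns differences per grid index unless it is given the coordinate arrays, so an index-based |∇p| changes with grid resolution. The sketch below compares both variants for the sample symbol p = x² + ξ², whose exact gradient norm is 2√(x² + ξ²). Central differences are exact for quadratics, so the coordinate-aware version matches the exact value in the interior up to rounding.

```python
import numpy as np

x = np.linspace(-2.0, 2.0, 101)
xi = np.linspace(-4.0, 4.0, 201)
X, XI = np.meshgrid(x, xi, indexing="ij")
p = X**2 + XI**2

gx_idx, gxi_idx = np.gradient(p)  # differences per grid index (no spacing given)
gx, gxi = np.gradient(p, x, xi)   # one coordinate array per axis: physical units

exact = 2.0 * np.sqrt(X**2 + XI**2)
norm_idx = np.hypot(gx_idx, gxi_idx)
norm = np.hypot(np.abs(gx), np.abs(gxi))  # np.abs keeps this valid for complex symbols

interior = (slice(1, -1), slice(1, -1))   # edges use one-sided differences
err = np.max(np.abs(norm[interior] - exact[interior]))
err_idx = np.max(np.abs(norm_idx[interior] - exact[interior]))
print(err, err_idx)  # tiny versus large: the index-based norm is off by the factor 0.04
```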
    def plot_hamiltonian_flow(self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
        """
        Integrate and plot the Hamiltonian trajectories of the symbol in phase space.

        This method numerically integrates the Hamiltonian vector field derived from
        the operator's symbol to visualize how singularities propagate under the flow.
        It supports both 1D and 2D problems.

        Parameters
        ----------
        x0, xi0 : float
            Initial position and frequency (momentum) along x.
        y0, eta0 : float, optional
            Initial position and frequency along y (2D only). Default is 0.0.
        tmax : float
            Final integration time for the ODE solver.
        n_steps : int
            Number of output times at which the trajectory is sampled (passed to
            `solve_ivp` as `t_eval`); the solver chooses its internal step size itself.
        show_field : bool, optional
            In 2D, overlay the spatial velocity field (dx/dt, dy/dt) evaluated at the
            fixed initial momentum (ξ₀, η₀). Ignored in 1D. Default is True.

        Notes
        -----
        - The Hamiltonian vector field is obtained from `symplectic_flow()`.
        - If the field is complex-valued, only its real part is used for integration.
        - In 1D, the trajectory is plotted in the (x, ξ) phase plane.
        - In 2D, the spatial trajectory (x(t), y(t)) is shown along with instantaneous
          momentum vectors (ξ(t), η(t)) using a quiver plot.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1 or 2.

        Displays
        --------
        matplotlib plot
            Phase-space trajectory showing the evolution of position and momentum
            under the Hamiltonian dynamics.
        """
        def make_real(expr):
            from sympy import re, simplify
            expr = expr.doit(deep=True)
            return simplify(re(expr))

        H = self.symplectic_flow()

        if any(im(H[k]) != 0 for k in H):
            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")

        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)

            dxdt_expr = make_real(H['dx/dt'])
            dxidt_expr = make_real(H['dxi/dt'])

            dxdt = lambdify((x, xi), dxdt_expr, 'numpy')
            dxidt = lambdify((x, xi), dxidt_expr, 'numpy')

            def hamilton(t, Y):
                x, xi = Y
                return [dxdt(x, xi), dxidt(x, xi)]

            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0], t_eval=np.linspace(0, tmax, n_steps))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_steps:
                print(f"⚠️ Only {n_points} of {n_steps} time points were computed.")

            x_vals, xi_vals = sol.y

            plt.plot(x_vals, xi_vals)
            plt.xlabel("x")
            plt.ylabel("ξ")
            plt.title("Hamiltonian Flow in Phase Space (1D)")
            plt.grid(True)
            plt.show()

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)

            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')

            def hamilton(t, Y):
                x, y, xi, eta = Y
                return [
                    dxdt(x, y, xi, eta),
                    dydt(x, y, xi, eta),
                    dxidt(x, y, xi, eta),
                    detadt(x, y, xi, eta)
                ]

            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0], t_eval=np.linspace(0, tmax, n_steps))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_steps:
                print(f"⚠️ Only {n_points} of {n_steps} time points were computed.")

            x_vals, y_vals, xi_vals, eta_vals = sol.y

            plt.plot(x_vals, y_vals, label='Position')
            plt.quiver(x_vals, y_vals, xi_vals, eta_vals, scale=20, width=0.003, alpha=0.5, color='r')

            # Background velocity field, evaluated at the initial momentum (optional)
            if show_field:
                X, Y = np.meshgrid(np.linspace(x_vals.min(), x_vals.max(), 20),
                                   np.linspace(y_vals.min(), y_vals.max(), 20))
                XI, ETA = xi0 * np.ones_like(X), eta0 * np.ones_like(Y)
                # broadcast_to handles components that lambdify returns as constants
                U = np.broadcast_to(dxdt(X, Y, XI, ETA), X.shape)
                V = np.broadcast_to(dydt(X, Y, XI, ETA), X.shape)
                plt.quiver(X, Y, U, V, color='gray', alpha=0.2, scale=30, width=0.002)

            plt.xlabel("x")
            plt.ylabel("y")
            plt.title("Hamiltonian Flow: Spatial Trajectory (2D)")
            plt.legend()
            plt.grid(True)
            plt.axis('equal')
            plt.show()

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

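The 1D branch of `plot_hamiltonian_flow` integrates the canonical equations with `solve_ivp`. For the sample symbol p = x² + ξ², the flow is a rotation with angular frequency 2, so a trajectory returns to its starting point after t = π and p stays constant along it. The sketch below checks both properties with tolerances tighter than `solve_ivp`'s defaults (rtol=1e-3, atol=1e-6), which would otherwise dominate the error; the tolerance values are an illustrative choice.

```python
import numpy as np
from scipy.integrate import solve_ivp


def hamilton(t, state):
    """Canonical equations for p(x, xi) = x**2 + xi**2."""
    x, xi = state
    return [2.0 * xi, -2.0 * x]  # dx/dt = ∂p/∂ξ, dξ/dt = -∂p/∂x


x0, xi0 = 0.0, 5.0
t_eval = np.linspace(0.0, np.pi, 200)
sol = solve_ivp(hamilton, [0.0, np.pi], [x0, xi0], t_eval=t_eval, rtol=1e-10, atol=1e-12)

x_vals, xi_vals = sol.y
energy = x_vals**2 + xi_vals**2
print(sol.status, np.max(np.abs(energy - energy[0])))  # 0 and a tiny drift
print(x_vals[-1], xi_vals[-1])                          # back near (0, 5) after one period
```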
    def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
        """
        Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.

        The plotted vector field is (∂_ξ p, -∂_x p), where p(x, ξ) is the operator's symbol
        (`self.symbol`). This field governs the bicharacteristic flow in phase space.

        Parameters
        ----------
        xlim : tuple of float
            Range for the spatial variable x, as (x_min, x_max).
        klim : tuple of float
            Range for the frequency variable ξ, as (ξ_min, ξ_max).
        density : int
            Number of grid points per axis for the visualization grid.

        Raises
        ------
        NotImplementedError
            If called on a 2D operator (only the 1D implementation is available).

        Notes
        -----
        - Only supports one-dimensional operators.
        - Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
        - Numerical evaluation is done via lambdify with the NumPy backend.
        - The field is drawn with a matplotlib quiver plot to show vector directions.
        """
        if self.dim != 1:
            raise NotImplementedError("Only 1D version implemented.")

        x_vals = np.linspace(*xlim, density)
        xi_vals = np.linspace(*klim, density)
        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')

        x, = self.vars_x
        xi = symbols('xi', real=True)
        H = self.symplectic_flow()
        dxdt = lambdify((x, xi), simplify(H['dx/dt']), 'numpy')
        dxidt = lambdify((x, xi), simplify(H['dxi/dt']), 'numpy')

        # broadcast_to handles components that lambdify returns as constants
        U = np.broadcast_to(dxdt(X, XI), X.shape)
        V = np.broadcast_to(dxidt(X, XI), X.shape)

        plt.quiver(X, XI, U, V, scale=10, width=0.005)
        plt.xlabel('x')
        plt.ylabel(r'$\xi$')
        plt.title("Symplectic Vector Field (1D)")
        plt.grid(True)
        plt.show()

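A property worth keeping in mind when reading this quiver plot: a Hamiltonian vector field (∂_ξ p, -∂_x p) is divergence-free, because the mixed partial derivatives cancel. This is the infinitesimal form of Liouville's theorem (phase-space area is preserved by the flow). A short SymPy check for an arbitrarily chosen sample symbol:

```python
from sympy import diff, exp, simplify, sin, symbols

x, xi = symbols("x xi", real=True)
p = x**2 * xi**3 + sin(x) * exp(xi)  # sample smooth symbol (arbitrary choice)

U = diff(p, xi)   # dx/dt  =  ∂p/∂ξ
V = -diff(p, x)   # dξ/dt  = -∂p/∂x
divergence = simplify(diff(U, x) + diff(V, xi))
print(divergence)  # 0
```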
    def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=1e-3, density=300):
        """
        Visualize a micro-support estimate of the operator by plotting 1 / |p(x, ξ)|.

        Regions of phase space (x, ξ) where |p(x, ξ)| is small show up as large values of
        1/|p(x, ξ)|. These are the neighborhoods of the characteristic set p = 0, where the
        operator fails to be elliptic and where its singular behavior concentrates.

        Parameters
        ----------
        xlim : tuple
            Spatial domain limits (x_min, x_max).
        klim : tuple
            Frequency domain limits (ξ_min, ξ_max).
        threshold : float
            Lower bound applied to |p(x, ξ)| before inversion, so that the plotted values
            never exceed 1/threshold and near-zeros do not dominate the color scale.
        density : int
            Number of grid points along each axis for visualization resolution.

        Raises
        ------
        NotImplementedError
            If the operator is not one-dimensional (only 1D visualization is supported).

        Notes
        -----
        - This method evaluates the symbol p(x, ξ) over a grid and plots
          1 / max(|p|, threshold) to emphasize regions where the symbol is near zero.
        - The resulting plot helps identify characteristic sets.
        """
        if self.dim != 1:
            raise NotImplementedError("Only 1D micro-support visualization implemented.")

        x_vals = np.linspace(*xlim, density)
        xi_vals = np.linspace(*klim, density)
        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
        Z = np.abs(np.broadcast_to(self.p_func(X, XI), X.shape))

        # Clip |p| from below so that near-zeros do not swamp the color scale
        plt.contourf(X, XI, 1.0 / np.maximum(Z, threshold), levels=100, cmap='inferno')
        plt.colorbar(label=r'$1/|p(x,\xi)|$')
        plt.xlabel('x')
        plt.ylabel(r'$\xi$')
        plt.title("Micro-Support Estimate (1/|Symbol|)")
        plt.show()

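Plotting 1/|p| with only a tiny additive constant lets a single near-zero sample reach values around 1e10 and flatten the rest of the color scale. Clipping the denominator at `threshold` bounds the plotted values by 1/threshold instead. A NumPy sketch with the sample symbol p = ξ - x, whose characteristic set is the diagonal ξ = x (the grid and threshold values are illustrative assumptions):

```python
import numpy as np

threshold = 1e-3
x = np.linspace(-2.0, 2.0, 301)
xi = np.linspace(-2.0, 2.0, 301)
X, XI = np.meshgrid(x, xi, indexing="ij")
Z = np.abs(XI - X)  # |p(x, xi)| for p = xi - x; exactly zero on the diagonal

inv_raw = 1.0 / (Z + 1e-10)                   # reaches 1e10 on the diagonal
inv_clipped = 1.0 / np.maximum(Z, threshold)  # never exceeds 1 / threshold

print(inv_raw.max(), inv_clipped.max())
```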
    def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
        """
        Plot the group velocity field ∂p/∂ξ for 1D pseudo-differential operators.

        The group velocity is the speed at which a wave packet concentrated near frequency ξ
        propagates in a dispersive medium. It is the derivative of the symbol p(x, ξ) with
        respect to the frequency variable ξ.

        Parameters
        ----------
        xlim : tuple of float
            Spatial domain limits (x-axis).
        klim : tuple of float
            Frequency domain limits (ξ-axis).
        density : int
            Number of grid points per axis used for visualization.

        Raises
        ------
        NotImplementedError
            If called on a 2D operator, since this visualization is only implemented for 1D.

        Notes
        -----
        - Each arrow at (x, ξ) has a unit horizontal component and a vertical component
          ∂p/∂ξ, so its slope encodes the group velocity at that point.
        - Useful for analyzing wave propagation properties and dispersion relations.
        - The derivative is computed symbolically from `self.symbol`.
        """
        if self.dim != 1:
            raise NotImplementedError("Only 1D group velocity visualization implemented.")

        x, = self.vars_x
        xi = symbols('xi', real=True)
        dp_dxi = diff(self.symbol, xi)
        grad_func = lambdify((x, xi), dp_dxi, 'numpy')

        x_vals = np.linspace(*xlim, density)
        xi_vals = np.linspace(*klim, density)
        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
        # broadcast_to handles a constant derivative, which lambdify returns as a scalar
        V = np.broadcast_to(grad_func(X, XI), X.shape)

        plt.quiver(X, XI, np.ones_like(V), V, scale=10, width=0.004)
        plt.xlabel('x')
        plt.ylabel(r'$\xi$')
        plt.title("Group Velocity Field (1D)")
        plt.grid(True)
        plt.show()

    def animate_singularity(self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0,
                            tmax=4.0, n_frames=100, projection=None):
        """
        Animate the propagation of a singularity under the Hamiltonian flow.

        This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space
        according to the Hamiltonian dynamics induced by the principal symbol of the operator.
        The animation integrates the Hamiltonian equations of motion and supports several projections:
        position (x-y), frequency (ξ-η), or mixed phase-space coordinates.

        Parameters
        ----------
        xi0, eta0 : float
            Initial frequency components (ξ₀, η₀).
        x0, y0 : float
            Initial spatial coordinates (x₀, y₀).
        tmax : float
            Total time of integration (final animation time).
        n_frames : int
            Number of frames in the resulting animation.
        projection : str or None
            Type of projection to display:
                - 'position' : x vs y in 2D; x vs t in 1D
                - 'frequency': ξ vs η in 2D; ξ vs t in 1D
                - 'phase'    : x vs ξ in 1D; x vs η in 2D
                If None, defaults to 'phase' in 1D and 'position' in 2D.

        Returns
        -------
        matplotlib.animation.FuncAnimation
            Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.

        Raises
        ------
        ValueError
            If projection is not one of the modes listed above.

        Notes
        -----
        - In 1D, only one spatial and one frequency variable are used.
        - Complex-valued Hamiltonian fields are truncated to their real parts for integration.
        - Trajectories are shown with both the current point (red dot) and the path traced so far (dashed line).
        """
        rc('animation', html='jshtml')

        def make_real(expr):
            from sympy import re, simplify
            expr = expr.doit(deep=True)
            return simplify(re(expr))

        H = self.symplectic_flow()

        H = {k: v.doit(deep=True) for k, v in H.items()}

        if any(im(H[k]) != 0 for k in H):
            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")

        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)

            dxdt = lambdify((x, xi), make_real(H['dx/dt']), 'numpy')
            dxidt = lambdify((x, xi), make_real(H['dxi/dt']), 'numpy')

            def hamilton(t, Y):
                x, xi = Y
                return [dxdt(x, xi), dxidt(x, xi)]

            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0],
                            t_eval=np.linspace(0, tmax, n_frames))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_frames:
                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
                n_frames = n_points

            x_vals, xi_vals = sol.y
            t_vals = sol.t

            if projection is None:
                projection = 'phase'

            fig, ax = plt.subplots()
            point, = ax.plot([], [], 'ro')
            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)

            if projection == 'phase':
                ax.set_xlabel('x')
                ax.set_ylabel(r'$\xi$')
                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)

                def update(i):
                    point.set_data([x_vals[i]], [xi_vals[i]])
                    traj.set_data(x_vals[:i+1], xi_vals[:i+1])
                    return point, traj

            elif projection == 'position':
                # 1D has a single spatial coordinate: plot x(t) against time
                ax.set_xlabel('t')
                ax.set_ylabel('x')
                ax.set_xlim(0, tmax)
                ax.set_ylim(np.min(x_vals) - 1, np.max(x_vals) + 1)

                def update(i):
                    point.set_data([t_vals[i]], [x_vals[i]])
                    traj.set_data(t_vals[:i+1], x_vals[:i+1])
                    return point, traj

            elif projection == 'frequency':
                # 1D has a single frequency coordinate: plot ξ(t) against time
                ax.set_xlabel('t')
                ax.set_ylabel(r'$\xi$')
                ax.set_xlim(0, tmax)
                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)

                def update(i):
                    point.set_data([t_vals[i]], [xi_vals[i]])
                    traj.set_data(t_vals[:i+1], xi_vals[:i+1])
                    return point, traj

            else:
                raise ValueError("Invalid projection mode")

            ax.set_title(f"1D Singularity Flow ({projection})")
            ax.grid(True)
            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
            plt.close(fig)
            return ani

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)

            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')

            def hamilton(t, Y):
                x, y, xi, eta = Y
                return [
                    dxdt(x, y, xi, eta),
                    dydt(x, y, xi, eta),
                    dxidt(x, y, xi, eta),
                    detadt(x, y, xi, eta)
                ]

            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0],
                            t_eval=np.linspace(0, tmax, n_frames))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_frames:
                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
                n_frames = n_points

            x_vals, y_vals, xi_vals, eta_vals = sol.y

            if projection is None:
                projection = 'position'

            fig, ax = plt.subplots()
            point, = ax.plot([], [], 'ro')
            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)

            if projection == 'position':
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
                ax.set_ylim(np.min(y_vals) - 1, np.max(y_vals) + 1)

                def update(i):
                    point.set_data([x_vals[i]], [y_vals[i]])
                    traj.set_data(x_vals[:i+1], y_vals[:i+1])
                    return point, traj

            elif projection == 'frequency':
                ax.set_xlabel(r'$\xi$')
                ax.set_ylabel(r'$\eta$')
                ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)

                def update(i):
                    point.set_data([xi_vals[i]], [eta_vals[i]])
                    traj.set_data(xi_vals[:i+1], eta_vals[:i+1])
                    return point, traj

            elif projection == 'phase':
                ax.set_xlabel('x')
                ax.set_ylabel(r'$\eta$')
                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)

                def update(i):
                    point.set_data([x_vals[i]], [eta_vals[i]])
                    traj.set_data(x_vals[:i+1], eta_vals[:i+1])
                    return point, traj

            else:
                raise ValueError("Invalid projection mode")

            ax.set_title(f"2D Singularity Flow ({projection})")
            ax.grid(True)
            ax.axis('equal')
            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
            plt.close(fig)
            return ani

    def interactive_symbol_analysis(pseudo_op,
                                    xlim=(-2, 2), ylim=(-2, 2),
                                    xi_range=(0.1, 5), eta_range=(-5, 5),
                                    density=100):
        """
        Launch an interactive dashboard for symbol exploration using ipywidgets.

        This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol.
        It supports multiple visualization modes in 1D and 2D, including group velocity fields (1D only), micro-support estimates,
        symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.

        Parameters
        ----------
        pseudo_op : PseudoDifferentialOperator
            The pseudo-differential operator whose symbol is to be analyzed interactively.
        xlim, ylim : tuple of float
            Spatial domain limits along x and y axes respectively.
        xi_range, eta_range : tuple of float
            Frequency domain limits along ξ and η axes respectively.
        density : int
            Number of points per axis used to construct the evaluation grid. Controls resolution.

        Notes
        -----
        - In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
        - In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
        - Only the sliders used by the selected mode are displayed.
        - Visualization updates dynamically as parameters are adjusted via sliders or dropdown menus.
        - Supported visualization modes:
            'Symbol Amplitude'           : |p(x,ξ)| or |p(x,y,ξ,η)|
            'Symbol Phase'               : arg(p(x,ξ)) or similar in 2D
            'Micro-Support (1/|p|)'      : Reciprocal of symbol magnitude
            'Cotangent Fiber'            : Structure of symbol over frequency space at fixed x
            'Characteristic Set'         : Zero set approximation {p ≈ 0}
            'Characteristic Gradient'    : |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)|
            'Group Velocity Field'       : ∂p/∂ξ (1D only)
            'Symplectic Vector Field'    : (∂p/∂ξ, -∂p/∂x) in 1D; (∂p/∂ξ, ∂p/∂η) over the (x, y) plane in 2D
            'Hamiltonian Flow'           : Trajectories generated by the Hamiltonian vector field

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1D or 2D.

        Returns
        -------
        None
            The dashboard is rendered with display(); its figures update as widget values change.
        """
        dim = pseudo_op.dim
        # use the operator symbol p; in 'auto' mode pseudo_op.expr holds the original differential expression
        expr = pseudo_op.symbol
        vars_x = pseudo_op.vars_x

        mode_selector_1D = Dropdown(
            options=[
                'Symbol Amplitude',
                'Symbol Phase',
                'Micro-Support (1/|p|)',
                'Cotangent Fiber',
                'Characteristic Set',
                'Characteristic Gradient',
                'Group Velocity Field',
                'Symplectic Vector Field',
                'Hamiltonian Flow',
            ],
            value='Symbol Amplitude',
            description='Mode:'
        )

        mode_selector_2D = Dropdown(
            options=[
                'Symbol Amplitude',
                'Symbol Phase',
                'Micro-Support (1/|p|)',
                'Cotangent Fiber',
                'Characteristic Set',
                'Characteristic Gradient',
                'Symplectic Vector Field',
                'Hamiltonian Flow',
            ],
            value='Symbol Amplitude',
            description='Mode:'
        )

        x_vals = np.linspace(*xlim, density)
        if dim == 2:
            y_vals = np.linspace(*ylim, density)

        if dim == 1:
            x, = vars_x
            xi = symbols('xi', real=True)
            grad_func = lambdify((x, xi), diff(expr, xi), 'numpy')
            symplectic_func = lambdify((x, xi), [diff(expr, xi), -diff(expr, x)], 'numpy')
            symbol_func = lambdify((x, xi), expr, 'numpy')

            xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
            x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')

            def plot_1d(mode, xi0, x0):
                X = x_vals[:, None]

                if mode == 'Group Velocity Field':
                    # broadcast in case ∂p/∂ξ is constant and lambdify returns a scalar
                    V = np.broadcast_to(grad_func(X, xi0), X.shape)
                    plt.quiver(X, V, np.ones_like(V), V, scale=10, width=0.004)
                    plt.xlabel('x')
                    plt.title(f'Group Velocity Field at ξ={xi0:.2f}')

                elif mode == 'Micro-Support (1/|p|)':
                    Z = 1 / (np.abs(symbol_func(X, xi0)) + 1e-10)
                    plt.plot(x_vals, Z)
                    plt.xlabel('x')
                    plt.title(f'Micro-Support (1/|p|) at ξ={xi0:.2f}')

                elif mode == 'Symplectic Vector Field':
                    U, V = symplectic_func(X, xi0)
                    plt.quiver(X, V, U, V, scale=10, width=0.004)
                    plt.xlabel('x')
                    plt.title(f'Symplectic Field at ξ={xi0:.2f}')

                elif mode == 'Symbol Amplitude':
                    Z = np.abs(symbol_func(X, xi0))
                    plt.plot(x_vals, Z)
                    plt.xlabel('x')
                    plt.title(f'Symbol Amplitude |p(x,ξ)| at ξ={xi0:.2f}')

                elif mode == 'Symbol Phase':
                    Z = np.angle(symbol_func(X, xi0))
                    plt.plot(x_vals, Z)
                    plt.xlabel('x')
                    plt.title(f'Symbol Phase arg(p(x,ξ)) at ξ={xi0:.2f}')

                elif mode == 'Cotangent Fiber':
                    pseudo_op.visualize_fiber(x_vals, np.linspace(*xi_range, density), x0=x0)

                elif mode == 'Characteristic Set':
                    pseudo_op.visualize_characteristic_set(x_vals, np.linspace(*xi_range, density), x0=x0)

                elif mode == 'Characteristic Gradient':
                    pseudo_op.visualize_characteristic_gradient(x_vals, np.linspace(*xi_range, density), x0=x0)

                elif mode == 'Hamiltonian Flow':
                    pseudo_op.plot_hamiltonian_flow(x0=x0, xi0=xi0)

            # --- Dynamic container for sliders ---
            controls_box = VBox([mode_selector_1D, xi_slider, x_slider])
            # --- Function to adjust visible sliders based on mode ---
            def update_controls(change):
                mode = change['new']
                # modes that depend only on xi
                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)',
                            'Group Velocity Field', 'Symplectic Vector Field']:
                    controls_box.children = [mode_selector_1D, xi_slider]
                # modes that require xi and x
                elif mode in ['Hamiltonian Flow']:
                    controls_box.children = [mode_selector_1D, xi_slider, x_slider]
                # modes that require x (they are drawn at the base point x0)
                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
                    controls_box.children = [mode_selector_1D, x_slider]
            mode_selector_1D.observe(update_controls, names='value')
            update_controls({'new': mode_selector_1D.value})
            # --- Interactive binding ---
            out = interactive_output(plot_1d, {'mode': mode_selector_1D, 'xi0': xi_slider, 'x0': x_slider})
            display(VBox([controls_box, out]))

        elif dim == 2:
            x, y = vars_x
            xi, eta = symbols('xi eta', real=True)
            symplectic_func = lambdify((x, y, xi, eta), [diff(expr, xi), diff(expr, eta)], 'numpy')
            symbol_func = lambdify((x, y, xi, eta), expr, 'numpy')

            xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
            eta_slider = FloatSlider(min=eta_range[0], max=eta_range[1], step=0.1, value=1.0, description='η₀')
            x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
            y_slider = FloatSlider(min=ylim[0], max=ylim[1], step=0.1, value=0.0, description='y₀')

            def plot_2d(mode, xi0, eta0, x0, y0):
                X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')

                if mode == 'Micro-Support (1/|p|)':
                    Z = 1 / (np.abs(symbol_func(X, Y, xi0, eta0)) + 1e-10)
                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='inferno')
                    plt.colorbar(label='1/|p|')
                    plt.xlabel('x')
                    plt.ylabel('y')
                    plt.title(f'Micro-Support at ξ={xi0:.2f}, η={eta0:.2f}')

                elif mode == 'Symplectic Vector Field':
                    U, V = symplectic_func(X, Y, xi0, eta0)
                    plt.quiver(X, Y, U, V, scale=10, width=0.004)
                    plt.xlabel('x')
                    plt.ylabel('y')
                    plt.title(f'Symplectic Field at ξ={xi0:.2f}, η={eta0:.2f}')

                elif mode == 'Symbol Amplitude':
                    Z = np.abs(symbol_func(X, Y, xi0, eta0))
                    plt.pcolormesh(X, Y, Z, shading='auto')
                    plt.colorbar(label='|p(x,y,ξ,η)|')
                    plt.xlabel('x')
                    plt.ylabel('y')
                    plt.title(f'Symbol Amplitude at ξ={xi0:.2f}, η={eta0:.2f}')

                elif mode == 'Symbol Phase':
                    Z = np.angle(symbol_func(X, Y, xi0, eta0))
                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='twilight')
                    plt.colorbar(label='arg(p)')
                    plt.xlabel('x')
                    plt.ylabel('y')
                    plt.title(f'Symbol Phase at ξ={xi0:.2f}, η={eta0:.2f}')

                elif mode == 'Cotangent Fiber':
                    pseudo_op.visualize_fiber(np.linspace(*xi_range, density), np.linspace(*eta_range, density),
                                              x0=x0, y0=y0)

                elif mode == 'Characteristic Set':
                    pseudo_op.visualize_characteristic_set(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)

                elif mode == 'Characteristic Gradient':
                    pseudo_op.visualize_characteristic_gradient(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)

                elif mode == 'Hamiltonian Flow':
                    pseudo_op.plot_hamiltonian_flow(x0=x0, y0=y0, xi0=xi0, eta0=eta0)

            # --- Dynamic container for sliders ---
            controls_box = VBox([mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider])
            # --- Function to adjust visible sliders based on mode ---
            def update_controls(change):
                mode = change['new']
                # modes that depend only on xi and eta
                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)', 'Symplectic Vector Field']:
                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider]
                # modes that require xi, eta, x and y
                elif mode in ['Hamiltonian Flow']:
                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider]
                # modes that require x and y
                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
                    controls_box.children = [mode_selector_2D, x_slider, y_slider]
            mode_selector_2D.observe(update_controls, names='value')
            update_controls({'new': mode_selector_2D.value})
            # --- Interactive binding ---
            out = interactive_output(plot_2d, {'mode': mode_selector_2D, 'xi0': xi_slider, 'eta0': eta_slider, 'x0': x_slider, 'y0': y_slider})
            display(VBox([controls_box, out]))

        else:
            raise NotImplementedError("Only 1D and 2D supported")
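
The trajectory that animate_singularity animates is the solution of Hamilton's equations ẋ = ∂p/∂ξ, ξ̇ = -∂p/∂x. The sketch below integrates them with the same lambdify-plus-solve_ivp pattern for an assumed example symbol p(x, ξ) = ξ² + x² (a harmonic oscillator chosen for illustration). It writes the equations out directly instead of calling symplectic_flow(), whose output is assumed to have this form, and it tightens the solver tolerances, which the method leaves at their defaults.

```python
import numpy as np
from scipy.integrate import solve_ivp
from sympy import symbols, diff, lambdify

x, xi = symbols('x xi', real=True)
p = xi**2 + x**2  # assumed example symbol (harmonic oscillator)

# Hamilton's equations generated by the symbol p
dxdt = lambdify((x, xi), diff(p, xi), 'numpy')    # dx/dt  =  ∂p/∂ξ
dxidt = lambdify((x, xi), -diff(p, x), 'numpy')   # dξ/dt  = -∂p/∂x

def hamilton(t, state):
    xv, xiv = state
    return [dxdt(xv, xiv), dxidt(xv, xiv)]

tmax, n_frames = 4.0, 100
t_eval = np.linspace(0.0, tmax, n_frames)
sol = solve_ivp(hamilton, [0.0, tmax], [0.0, 5.0], t_eval=t_eval, rtol=1e-9, atol=1e-9)

x_t, xi_t = sol.y
energy = x_t**2 + xi_t**2  # p is conserved along the flow
```

For this symbol the exact solution is x(t) = 5 sin 2t, ξ(t) = 5 cos 2t, so the conserved value p = 25 gives a quick accuracy check on the integrator.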

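The micro-support and group-velocity diagnostics in the listing reduce to a few array operations. A minimal sketch, assuming the example symbol p(x, ξ) = ξ² - x², whose characteristic set {p = 0} is the pair of lines ξ = ±x: the reciprocal map 1/(|p| + 10⁻¹⁰) peaks on that set, and np.broadcast_to turns the plain scalar that lambdify returns for a constant derivative (here ∂(3ξ)/∂ξ = 3, another assumed symbol) into a full grid that plt.quiver can accept.

```python
import numpy as np
from sympy import symbols, diff, lambdify

x, xi = symbols('x xi', real=True)
p = xi**2 - x**2  # assumed example symbol; zero set is ξ = ±x

p_num = lambdify((x, xi), p, 'numpy')
v_num = lambdify((x, xi), diff(p, xi), 'numpy')  # group velocity ∂p/∂ξ = 2ξ

grid = np.linspace(-2, 2, 41)
X, XI = np.meshgrid(grid, grid, indexing='ij')

# Micro-support estimate: large where |p| is small
micro_support = 1.0 / (np.abs(p_num(X, XI)) + 1e-10)
peak = np.unravel_index(np.argmax(micro_support), micro_support.shape)

# Group velocity on the grid (already an array for this symbol)
V = np.broadcast_to(v_num(X, XI), X.shape)

# A constant derivative comes back from lambdify as a scalar; broadcast it
const_num = lambdify((x, xi), diff(3 * xi, xi), 'numpy')
V_const = np.broadcast_to(const_num(X, XI), X.shape)
```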
Pseudo-differential operator with dynamic symbol evaluation on spatial grids. Supports both 1D and 2D operators, and can be defined explicitly (symbol mode) or extracted automatically from symbolic equations (auto mode).

Parameters

expr : sympy expression
    In 'symbol' mode, the pseudo-differential symbol itself; in 'auto' mode, the differential expression acting on var_u.
vars_x : list of sympy symbols
    Spatial variables (e.g., [x] for 1D, [x, y] for 2D).
var_u : sympy function, optional
    Applied function (e.g., u(x)) that auto mode replaces with a plane wave to extract the symbol. Required when mode='auto'.
mode : str, {'symbol', 'auto'}
    - 'symbol': uses expr directly as the operator symbol.
    - 'auto': computes the symbol by applying expr to exp(i x ξ) (exp(i(x ξ + y η)) in 2D) and dividing by that exponential.

Attributes

dim : int
    Spatial dimension (1 or 2).
fft, ifft : callable
    Forward and inverse FFT from scipy.fft (fft/ifft in 1D, fft2/ifft2 in 2D), bound with workers=FFT_WORKERS.
symbol : sympy expression
    The operator symbol, as given (symbol mode) or as derived (auto mode).
p_func : callable
    Lambdified symbol ready for numerical use: p_func(x, xi) in 1D, p_func(x, y, xi, eta) in 2D.

Notes

  • In 'symbol' mode, expr should be expressed in terms of spatial variables and frequency variables (ξ, η).
  • In 'auto' mode, the symbol is derived by applying the differential expression to a complex exponential, and the result is pretty-printed.
  • Frequency variables are internally named 'xi' and 'eta' for consistency.
  • Uses numpy for numerical evaluation and scipy.fft for FFT operations.

Examples

>>> # Example 1: 1D negative Laplacian -d²/dx², whose symbol is ξ² (symbol mode)
>>> from sympy import symbols
>>> x, xi = symbols('x xi', real=True)
>>> op = PseudoDifferentialOperator(expr=xi**2, vars_x=[x], mode='symbol')
>>> # Example 2: 1D transport operator d/dx (auto mode)
>>> from sympy import Function
>>> u = Function('u')
>>> expr = u(x).diff(x)
>>> op = PseudoDifferentialOperator(expr=expr, vars_x=[x], var_u=u(x), mode='auto')
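The 'auto' mode extraction can be checked by hand: substitute the plane wave for the unknown function, then divide it back out. A minimal SymPy sketch of that idea; the helper symbol_from_operator is hypothetical (not part of psipy), and its explicit .doit() is an addition of this sketch that forces the substituted derivatives to be evaluated before simplification.

```python
from sympy import symbols, Function, exp, I, simplify, expand

def symbol_from_operator(expr, var_u, plane_wave):
    """Hypothetical helper: apply expr to a plane wave and divide the wave back out."""
    applied = expr.subs(var_u, plane_wave).doit()  # evaluate d/dx acting on the plane wave
    return expand(simplify(applied / plane_wave))

# 1D: plane wave exp(i x ξ)
x, xi = symbols('x xi', real=True)
u = Function('u')
wave_1d = exp(I * x * xi)

p_transport = symbol_from_operator(u(x).diff(x), u(x), wave_1d)         # d/dx
p_neg_lap = symbol_from_operator(-u(x).diff(x, 2), u(x), wave_1d)       # -d²/dx²
p_variable = symbol_from_operator(x**2 * u(x).diff(x), u(x), wave_1d)   # x² d/dx

# 2D: plane wave exp(i(x ξ + y η))
y, eta = symbols('y eta', real=True)
v = Function('v')
wave_2d = exp(I * (x * xi + y * eta))
p_lap_2d = symbol_from_operator(v(x, y).diff(x, 2) + v(x, y).diff(y, 2), v(x, y), wave_2d)
```

The results are iξ, ξ², i x² ξ and -(ξ² + η²); the second matches the symbol used in Example 1 above.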
PseudoDifferentialOperator(expr, vars_x, var_u=None, mode='symbol')
    def __init__(self, expr, vars_x, var_u=None, mode='symbol'):
        self.dim = len(vars_x)
        self.mode = mode
        self.symbol_cached = None
        self.expr = expr
        self.vars_x = vars_x

        if self.dim == 1:
            x, = vars_x
            xi_internal = symbols('xi', real=True)
            expr = expr.subs(symbols('xi', real=True), xi_internal)
            self.fft = partial(fft, workers=FFT_WORKERS)
            self.ifft = partial(ifft, workers=FFT_WORKERS)

            if mode == 'symbol':
                self.p_func = lambdify((x, xi_internal), expr, 'numpy')
                self.symbol = expr
            elif mode == 'auto':
                if var_u is None:
                    raise ValueError("var_u must be provided in mode='auto'")
                exp_i = exp(I * x * xi_internal)
                P_ei = expr.subs(var_u, exp_i)
                symbol = simplify(P_ei / exp_i)
                symbol = expand(symbol)
                self.symbol = symbol
                self.p_func = lambdify((x, xi_internal), symbol, 'numpy')
            else:
                raise ValueError("mode must be 'auto' or 'symbol'")

        elif self.dim == 2:
            x, y = vars_x
            xi_internal, eta_internal = symbols('xi eta', real=True)
            expr = expr.subs(symbols('xi', real=True), xi_internal)
            expr = expr.subs(symbols('eta', real=True), eta_internal)
            self.fft = partial(fft2, workers=FFT_WORKERS)
            self.ifft = partial(ifft2, workers=FFT_WORKERS)

            if mode == 'symbol':
                self.symbol = expr
                self.p_func = lambdify((x, y, xi_internal, eta_internal), expr, 'numpy')
            elif mode == 'auto':
                if var_u is None:
                    raise ValueError("var_u must be provided in mode='auto'")
                exp_i = exp(I * (x * xi_internal + y * eta_internal))
                P_ei = expr.subs(var_u, exp_i)
                symbol = simplify(P_ei / exp_i)
                symbol = expand(symbol)
                self.symbol = symbol
                self.p_func = lambdify((x, y, xi_internal, eta_internal), symbol, 'numpy')
            else:
                raise ValueError("mode must be 'auto' or 'symbol'")

        else:
            raise NotImplementedError("Only 1D and 2D supported")

        if mode == 'auto':
            print("\nsymbol = ")
            pprint(self.symbol, num_columns=NUM_COLS)
dim
mode
symbol_cached
expr
vars_x
def evaluate(self, X, Y, KX, KY, cache=True):
    def evaluate(self, X, Y, KX, KY, cache=True):
        """
        Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.

        The method dynamically selects between 1D and 2D evaluation based on the spatial dimension.
        If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation.
        The cache is not keyed on the input grids: call clear_cache() before evaluating on a different grid.

        Parameters
        ----------
        X, Y : ndarray
            Spatial grid coordinates. In 1D, Y is ignored.
        KX, KY : ndarray
            Frequency grid coordinates. In 1D, KY is ignored.
        cache : bool, default=True
            If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.

        Returns
        -------
        ndarray
            Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids,
            unless a cached array from an earlier call is returned.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1D or 2D.
        """
        if cache and self.symbol_cached is not None:
            return self.symbol_cached

        if self.dim == 1:
            symbol = self.p_func(X, KX)
        elif self.dim == 2:
            symbol = self.p_func(X, Y, KX, KY)
        else:
            raise NotImplementedError("Only 1D and 2D supported")

        if cache:
            self.symbol_cached = symbol

        return symbol

Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.

The method dynamically selects between 1D and 2D evaluation based on the spatial dimension. If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation. The cache is not keyed on the input grids: call clear_cache() before evaluating on a different grid.

Parameters

X, Y : ndarray
    Spatial grid coordinates. In 1D, Y is ignored.
KX, KY : ndarray
    Frequency grid coordinates. In 1D, KY is ignored.
cache : bool, default=True
    If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.

Returns

ndarray
    Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids, unless a cached array from an earlier call is returned.

Raises

NotImplementedError
    If the spatial dimension is not 1D or 2D.
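Because the cache ignores its inputs, a second call on a different grid returns the first result. The sketch below reproduces the cache-then-return logic of evaluate in a small hypothetical class, CachedSymbol (not part of psipy), using an assumed 1D symbol p = ξ² + x², to show why clear_cache() is needed when the grid changes.

```python
import numpy as np
from sympy import symbols, lambdify

x, xi = symbols('x xi', real=True)
p_func = lambdify((x, xi), xi**2 + x**2, 'numpy')  # assumed example symbol

class CachedSymbol:
    """Hypothetical stand-in mirroring the caching logic of evaluate() (1D only)."""

    def __init__(self, func):
        self.func = func
        self.symbol_cached = None

    def evaluate(self, X, KX, cache=True):
        if cache and self.symbol_cached is not None:
            return self.symbol_cached  # returned even if X and KX have changed
        values = self.func(X, KX)
        if cache:
            self.symbol_cached = values
        return values

    def clear_cache(self):
        self.symbol_cached = None

X1, K1 = np.meshgrid(np.linspace(-1, 1, 4), np.linspace(-2, 2, 5), indexing='ij')
X2, K2 = np.meshgrid(np.linspace(-1, 1, 8), np.linspace(-2, 2, 9), indexing='ij')

op = CachedSymbol(p_func)
first = op.evaluate(X1, K1)   # shape (4, 5)
stale = op.evaluate(X2, K2)   # still the (4, 5) array from the first call
op.clear_cache()
fresh = op.evaluate(X2, K2)   # recomputed on the new (8, 9) grid
```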

def clear_cache(self):
    def clear_cache(self):
        """
        Clear the cached symbol array so that the next evaluate() call recomputes it.
        """
        self.symbol_cached = None

Clear the cached symbol array so that the next evaluate() call recomputes it.
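For a constant-coefficient symbol on a periodic grid, apply takes the Fourier-multiplier path 𝓕⁻¹[p(ξ) · 𝓕(u)]. A minimal NumPy sketch of that idea, checked on p(ξ) = ξ², which should act as -d²/dx²; the helper apply_fourier_multiplier is hypothetical and only illustrates the technique, it is not the library's _apply_constant_fft, whose details are not shown here.

```python
import numpy as np

def apply_fourier_multiplier(u, symbol, length):
    """Hypothetical helper: apply a constant symbol p(ξ) to periodic samples u."""
    n = u.size
    kx = 2 * np.pi * np.fft.fftfreq(n, d=length / n)  # angular wavenumbers ξ
    return np.fft.ifft(symbol(kx) * np.fft.fft(u))

L = 2 * np.pi
x = np.linspace(0, L, 64, endpoint=False)  # periodic grid without the duplicate endpoint
u = np.sin(3 * x)

result = apply_fourier_multiplier(u, lambda k: k**2, L)  # symbol ξ², i.e. -d²/dx²
```

Since u = sin 3x, -u'' = 9 sin 3x. The wavenumbers are built as 2π · fftfreq so that ξ is the angular frequency conjugate to x; the imaginary part of the result is round-off only.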

def apply(self, u, x_grid, kx, boundary_condition='periodic', y_grid=None, ky=None, dealiasing_mask=None, freq_window='gaussian', clamp=1000000.0, space_window=False):
    def apply(self, u, x_grid, kx, boundary_condition='periodic',
              y_grid=None, ky=None, dealiasing_mask=None,
              freq_window='gaussian', clamp=1e6, space_window=False):
        """
        Apply the pseudo-differential operator to the input field u.

        This method dispatches the application of the pseudo-differential operator based on:

        - Whether the symbol is spatially dependent (x/y)
        - The boundary condition in use (periodic or dirichlet)

        Supported operations:

        - Constant-coefficient symbols: applied via Fourier multiplication.
        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.

        Dispatch logic:

        - Constant symbol, periodic BC: u ↦ Op(p)(D) u = 𝓕⁻¹[ p(ξ) 𝓕(u) ]
        - Spatial symbol, periodic BC: u ↦ Op(p)(x, D) u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated with FFTs (faster)
        - Dirichlet BC: u ↦ Op(p)(x, D) u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated without assuming periodicity (slower)

        Parameters
        ----------
        u : ndarray
            Function to which the operator is applied
        x_grid : ndarray
            Spatial grid in x direction
        kx : ndarray
            Frequency grid in x direction
        boundary_condition : str
            'periodic' or 'dirichlet'
        y_grid : ndarray, optional
            Spatial grid in y direction (for 2D)
        ky : ndarray, optional
            Frequency grid in y direction (for 2D)
        dealiasing_mask : ndarray, optional
            Dealiasing mask (used only for constant symbols with periodic BC)
        freq_window : str
            Frequency windowing ('gaussian' or 'hann')
        clamp : float
            Clamp symbol values to [-clamp, clamp]
        space_window : bool
            Apply spatial windowing

        Returns
        -------
        ndarray
            Result of applying the operator

        Raises
        ------
        ValueError
            If boundary_condition is neither 'periodic' nor 'dirichlet'.
        NotImplementedError
            If Dirichlet BC is requested outside 1D or 2D.
        """
        # Check if symbol depends on spatial variables
        is_spatial = self._is_spatial_dependent()

        # Case 1: Constant symbol with periodic BC (fast path)
        if not is_spatial and boundary_condition == 'periodic':
            return self._apply_constant_fft(u, x_grid, kx, y_grid, ky, dealiasing_mask)

        # Case 2: Spatial symbol with periodic BC
        elif boundary_condition == 'periodic':
            symbol_func = self._get_symbol_func()
            return kohn_nirenberg_fft(
                u_vals=u,
                symbol_func=symbol_func,
                x_grid=x_grid,
                kx=kx,
                fft_func=self.fft,
                ifft_func=self.ifft,
                dim=self.dim,
                y_grid=y_grid,
                ky=ky,
                freq_window=freq_window,
                clamp=clamp,
                space_window=space_window
            )

        # Case 3: Dirichlet BC (non-periodic)
        elif boundary_condition == 'dirichlet':
            symbol_func = self._get_symbol_func()

            if self.dim == 1:
                return kohn_nirenberg_nonperiodic(
                    u_vals=u,
                    x_grid=x_grid,
                    xi_grid=kx,
                    symbol_func=symbol_func,
                    freq_window=freq_window,
                    clamp=clamp,
                    space_window=space_window
                )
            elif self.dim == 2:
                return kohn_nirenberg_nonperiodic(
                    u_vals=u,
                    x_grid=(x_grid, y_grid),
                    xi_grid=(kx, ky),
                    symbol_func=symbol_func,
                    freq_window=freq_window,
                    clamp=clamp,
                    space_window=space_window
                )
            else:
                # Other dimensions are not supported: fail loudly instead of returning None
                raise NotImplementedError("Dirichlet BC is only implemented in 1D and 2D.")

        else:
            raise ValueError(f"Invalid boundary condition '{boundary_condition}'")

Apply the pseudo-differential operator to the input field u.

This method dispatches the application of the pseudo-differential operator based on:

  • Whether the symbol is spatially dependent (x/y)
  • The boundary condition in use (periodic or dirichlet)

Supported operations:

  • Constant-coefficient symbols: applied via Fourier multiplication.
  • Spatially varying symbols: applied via Kohn–Nirenberg quantization.
  • Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.

Dispatch logic:

  • Constant symbol, periodic BC: u ↦ Op(p)(D) u = 𝓕⁻¹[ p(ξ) 𝓕(u) ]
  • Spatial symbol, periodic BC: u ↦ Op(p)(x, D) u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated with FFTs (faster)
  • Dirichlet BC: u ↦ Op(p)(x, D) u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated without assuming periodicity (slower)

Parameters

u : ndarray
    Function to which the operator is applied
x_grid : ndarray
    Spatial grid in x direction
kx : ndarray
    Frequency grid in x direction
boundary_condition : str
    'periodic' or 'dirichlet'
y_grid : ndarray, optional
    Spatial grid in y direction (for 2D)
ky : ndarray, optional
    Frequency grid in y direction (for 2D)
dealiasing_mask : ndarray, optional
    Dealiasing mask (used only for constant symbols with periodic BC)
freq_window : str
    Frequency windowing ('gaussian' or 'hann')
clamp : float
    Clamp symbol values to [-clamp, clamp]
space_window : bool
    Apply spatial windowing

Returns

ndarray Result of applying the operator
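The two periodic paths can be illustrated with plain NumPy. The sketch below is not the library's `_apply_constant_fft` or `kohn_nirenberg_fft` (their windowing and clamping are not reproduced); `fourier_multiplier` and `kohn_nirenberg_periodic` are hypothetical helpers. It assumes a uniform periodic grid of length L with angular frequencies taken from `np.fft.fftfreq`, and evaluates the Kohn–Nirenberg sum as a dense O(N²) quadrature, so it only suits small grids.

```python
import numpy as np


def fourier_multiplier(u, L, symbol):
    """Constant symbol: Op(p)(D) u = F^{-1}[ p(xi) F(u) ] on a periodic grid of length L."""
    n = u.size
    xi = 2 * np.pi * np.fft.fftfreq(n, d=L / n)  # angular frequencies
    return np.fft.ifft(symbol(xi) * np.fft.fft(u))


def kohn_nirenberg_periodic(u, x, L, symbol):
    """Spatial symbol: (Op(p)u)(x_j) = (1/n) sum_k e^{i (x_j - x_0) xi_k} p(x_j, xi_k) u_hat_k.

    Dense O(n^2) evaluation of the Kohn-Nirenberg sum, for illustration only.
    """
    n = u.size
    xi = 2 * np.pi * np.fft.fftfreq(n, d=L / n)
    u_hat = np.fft.fft(u)
    X, XI = np.meshgrid(x, xi, indexing="ij")  # X[j, k] = x_j, XI[j, k] = xi_k
    # The phase uses x_j - x_0 so that the sum inverts np.fft.fft for any grid origin
    kernel = np.exp(1j * (X - x[0]) * XI) * symbol(X, XI)
    return kernel @ u_hat / n


if __name__ == "__main__":
    L, n = 2 * np.pi, 64
    x = np.linspace(0.0, L, n, endpoint=False)
    u = np.sin(x)
    # xi**2 is the symbol of -d^2/dx^2, and -(sin x)'' = sin x
    print(np.allclose(fourier_multiplier(u, L, lambda xi: xi**2), u))
    print(np.allclose(kohn_nirenberg_periodic(u, x, L, lambda X, XI: XI**2), u))
```

For a constant symbol both helpers agree. With p(x, ξ) = (1 + ½cos x) ξ and u = sin x, the Kohn–Nirenberg sum returns a(x)·Du = −i(1 + ½cos x) cos x, since Op(ξ) = D = −i∂ₓ.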

def principal_symbol(self, order=1):
    def principal_symbol(self, order=1):
        """
        Compute the leading homogeneous component of the pseudo-differential symbol.

        This method extracts the principal part of the symbol, which is the dominant
        term under high-frequency asymptotics (|ξ| → ∞). The expansion is performed
        in polar coordinates for 2D symbols to maintain rotational symmetry, then
        converted back to Cartesian form.

        Parameters
        ----------
        order : int
            Order of the asymptotic expansion in powers of 1/ρ, where ρ = |ξ| in 1D
            or ρ = sqrt(ξ² + η²) in 2D. Only the leading-order term is returned.

        Returns
        -------
        sympy.Expr
            The principal symbol component, homogeneous of degree `m - order`, where
            `m` is the original symbol's order.

        Notes:
        - In 1D, uses direct series expansion in ξ.
        - In 2D, expands in radial variable ρ while preserving angular dependence.
        - Useful for microlocal analysis and constructing parametrices.
        """

        p = self.symbol
        if self.dim == 1:
            xi = symbols('xi', real=True, positive=True)
            return simplify(series(p, xi, oo, n=order).removeO())
        elif self.dim == 2:
            xi, eta = symbols('xi eta', real=True, positive=True)
            # Homogeneous radial expansion: we set (ξ, η) = ρ (cosθ, sinθ)
            rho, theta = symbols('rho theta', real=True, positive=True)
            p_rho = p.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
            expansion = series(p_rho, rho, oo, n=order).removeO()
            # Revert back to (ξ, η)
            expansion_cart = expansion.subs({rho: sqrt(xi**2 + eta**2),
                                             cos(theta): xi / sqrt(xi**2 + eta**2),
                                             sin(theta): eta / sqrt(xi**2 + eta**2)})
            return simplify(powdenest(expansion_cart, force=True))

Compute the leading homogeneous component of the pseudo-differential symbol.

This method extracts the principal part of the symbol, which is the dominant term under high-frequency asymptotics (|ξ| → ∞). The expansion is performed in polar coordinates for 2D symbols to maintain rotational symmetry, then converted back to Cartesian form.

Parameters

order : int Order of the asymptotic expansion in powers of 1/ρ, where ρ = |ξ| in 1D or ρ = sqrt(ξ² + η²) in 2D. Only the leading-order term is returned.

Returns

sympy.Expr The principal symbol component, homogeneous of degree m - order, where m is the original symbol's order.

Notes:

  • In 1D, uses direct series expansion in ξ.
  • In 2D, expands in radial variable ρ while preserving angular dependence.
  • Useful for microlocal analysis and constructing parametrices.
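For polynomial symbols, the 2D polar strategy can be reproduced in a few lines of SymPy. This is a sketch under that assumption, not the method above: `principal_part_polynomial_2d` is a hypothetical helper that keeps the top power of ρ exactly through `sympy.Poly` instead of calling `series` at infinity, then maps back with ρ = √(ξ² + η²).

```python
import sympy as sp

xi, eta = sp.symbols('xi eta', real=True)
rho, theta = sp.symbols('rho theta', positive=True)


def principal_part_polynomial_2d(p):
    """Return (leading homogeneous part, degree) of a polynomial symbol p(xi, eta)."""
    # (xi, eta) = rho * (cos(theta), sin(theta))
    p_polar = sp.expand(p.subs({xi: rho * sp.cos(theta), eta: rho * sp.sin(theta)}))
    poly = sp.Poly(p_polar, rho)
    m = poly.degree()
    lead = poly.LC() * rho**m  # angular coefficient times rho**m
    # Back to Cartesian: cos(theta) = xi/norm, sin(theta) = eta/norm, rho = norm
    norm = sp.sqrt(xi**2 + eta**2)
    back = lead.subs({sp.cos(theta): xi / norm, sp.sin(theta): eta / norm}).subs(rho, norm)
    return sp.simplify(back), m


if __name__ == "__main__":
    print(principal_part_polynomial_2d(xi**2 + eta**2 + 3 * xi + 1))
```

Lower-order terms such as 3ξ + 1 drop out, and the angular factor cos²θ + sin²θ collapses back to the Cartesian form ξ² + η².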
def is_homogeneous(self, tol=1e-10):
    def is_homogeneous(self, tol=1e-10):
        """
        Check whether the symbol is homogeneous in the frequency variables.

        Returns
        -------
        (bool, Rational or float or None)
            Tuple (is_homogeneous, degree) where:
            - is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η)
            - degree: the detected degree m if homogeneous, or None
        """
        from sympy import symbols, simplify

        # Scaling parameter λ > 0 (tol is currently not used)
        l = symbols('l', real=True, positive=True)
        p = self.symbol

        if self.dim == 1:
            xi = symbols('xi', real=True, positive=True)
            p_scaled = p.subs(xi, l * xi)
            freqs = (xi,)
        elif self.dim == 2:
            xi, eta = symbols('xi eta', real=True, positive=True)
            p_scaled = p.subs({xi: l * xi, eta: l * eta})
            freqs = (xi, eta)
        else:
            raise NotImplementedError("Only 1D and 2D supported.")

        ratio = simplify(p_scaled / p)
        # Homogeneous iff the ratio is l**m with no frequency variable left
        if ratio.has(*freqs):
            return False, None
        if ratio == 1:
            return True, 0  # p(λξ) = p(ξ): homogeneous of degree 0
        base, deg = ratio.as_base_exp()
        if base == l:
            return True, deg
        return False, None

Check whether the symbol is homogeneous in the frequency variables.

Returns

(bool, Rational or float or None)
    Tuple (is_homogeneous, degree) where:
      • is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η)
      • degree: the detected degree m if homogeneous, or None
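One compact way to read off the degree, which also covers the scale-invariant case p(λξ) = p(ξ) (degree 0), is to take log(ratio)/log λ once the ratio is known to be free of the frequency variables. The helper `homogeneity_degree` below is a hypothetical stand-alone sketch in plain SymPy, independent of the class; the caller passes the frequency symbols explicitly.

```python
import sympy as sp


def homogeneity_degree(p, freqs):
    """Return m if p(l*freqs) == l**m * p(freqs) for all l > 0, else None."""
    l = sp.Symbol('l', positive=True)
    p_scaled = p.subs({f: l * f for f in freqs}, simultaneous=True)
    ratio = sp.simplify(p_scaled / p)
    if ratio.has(*freqs):
        return None  # the ratio still depends on the frequencies: not homogeneous
    # ratio == l**m  =>  log(ratio) / log(l) == m (gives 0 when ratio == 1)
    m = sp.simplify(sp.expand_log(sp.log(ratio), force=True) / sp.log(l))
    return m if not m.has(l) else None


if __name__ == "__main__":
    xi, eta = sp.symbols('xi eta', real=True)
    print(homogeneity_degree(xi * eta, [xi, eta]))
    print(homogeneity_degree(xi**2 + 1, [xi]))
```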

def symbol_order(self, max_order=10, tol=0.001):
    def symbol_order(self, max_order=10, tol=1e-3):
        """
        Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.

        This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η)
        as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate,
        which is essential for understanding the regularity and mapping properties of the corresponding operator.

        The function uses symbolic preprocessing to ensure proper factorization of frequency variables,
        especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).

        Parameters
        ----------
        max_order : int, optional
            Maximum number of terms to consider in the series expansion. Default is 10.
        tol : float, optional
            Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small,
            the detected order may be discarded. Default is 1e-3.

        Returns
        -------
        float or None
            - If the symbol is homogeneous, returns its exact homogeneity degree as a float.
            - Otherwise, estimates the dominant asymptotic order from leading terms in the expansion.
            - Returns None if no valid order could be determined.

        Notes
        -----
        - In 1D:
            Two strategies are used:
                1. Expand directly in xi at infinity.
                2. Substitute xi = 1/z and expand around z = 0.

        - In 2D:
            - Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
            - Expand in rho at infinity, then extract the leading term's power.
            - An alternative substitution using 1/z is also tried if the first method fails.

        - Preprocessing steps:
            - Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
            - Power expressions are factored explicitly to ensure correct symbolic scaling.

        - If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.

        - For non-homogeneous symbols, only the principal asymptotic term is considered.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is neither 1 nor 2.
        """
        from sympy import (
            symbols, series, simplify, sqrt, cos, sin, oo, powdenest, radsimp, Poly, S
        )

        def preprocess_sqrt(expr, freq):
            # sqrt(a) is stored as Pow(a, 1/2), so match on the exponent rather than on `func`.
            # Rewrite sqrt(a) = freq * sqrt(1 + (a - freq**2) / freq**2), exact for freq > 0.
            return expr.replace(
                lambda e: e.is_Pow and e.exp == S.Half and freq in e.free_symbols,
                lambda e: freq * sqrt(1 + (e.base - freq**2) / freq**2)
            )

        def preprocess_power(expr, freq):
            # Factor the leading power of freq out of each power's base:
            # b**a = freq**(d*a) * (b / freq**d)**a, with d the degree of b in freq (exact for freq > 0).
            def factor_out(e):
                try:
                    d = Poly(e.base, freq).degree()
                except Exception:
                    return e  # base is not polynomial in freq: leave it unchanged
                if d <= 0:
                    return e
                return freq**(d * e.exp) * (e.base / freq**d)**e.exp

            return expr.replace(
                lambda e: e.is_Pow and freq in e.base.free_symbols,
                factor_out
            )

        def validate_order(power, coeff, vars_x, tol):
            if power is None:
                return None
            if any(v in coeff.free_symbols for v in vars_x):
                print("⚠️ Coefficient depends on spatial variables; ignoring")
                return None
            try:
                coeff_val = abs(float(coeff.evalf()))
                if coeff_val < tol:
                    print(f"⚠️ Coefficient too small ({coeff_val:.2e} < {tol})")
                    return None
            except Exception as e:
                print(f"⚠️ Coefficient evaluation failed: {e}")
                return None
            return int(power) if power == int(power) else float(power)

        # Homogeneity check
        is_homog, degree = self.is_homogeneous()
        if is_homog:
            return float(degree)
        else:
            print("⚠️ The symbol is not homogeneous. The asymptotic order is not well defined.")

        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True, positive=True)

            try:
                print("1D symbol_order - method 1")
                expr = preprocess_sqrt(self.symbol, xi)
                s = series(expr, xi, oo, n=max_order).removeO()
                lead = simplify(powdenest(s.as_leading_term(xi), force=True))
                power = lead.as_powers_dict().get(xi, None)
                coeff = lead / xi**power if power is not None else 0
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x], tol)
                if order is not None:
                    return order
            except Exception:
                pass

            try:
                print("1D symbol_order - method 2")
                z = symbols('z', real=True, positive=True)
                expr_z = preprocess_sqrt(self.symbol.subs(xi, 1/z), 1/z)
                s = series(expr_z, z, 0, n=max_order).removeO()
                lead = simplify(powdenest(s.as_leading_term(z), force=True))
                power = lead.as_powers_dict().get(z, None)
                coeff = lead / z**power if power is not None else 0
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x], tol)
                if order is not None:
                    return -order
            except Exception as e:
                print(f"⚠️ fallback z failed: {e}")
            return None

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True, positive=True)
            rho, theta = symbols('rho theta', real=True, positive=True)

            try:
                print("2D symbol_order - method 1")
                p_rho = self.symbol.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
                p_rho = preprocess_power(preprocess_sqrt(p_rho, rho), rho)
                s = series(simplify(p_rho), rho, oo, n=max_order).removeO()
                lead = radsimp(simplify(powdenest(s.as_leading_term(rho), force=True)))
                power = lead.as_powers_dict().get(rho, None)
                coeff = lead / rho**power if power is not None else 0
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x, y], tol)
                if order is not None:
                    return order
            except Exception as e:
                print(f"⚠️ polar expansion failed: {e}")

            try:
                print("2D symbol_order - method 2")
                z = symbols('z', real=True, positive=True)
                xi_eta = {xi: (1/z) * cos(theta), eta: (1/z) * sin(theta)}
                p_rho = preprocess_sqrt(self.symbol.subs(xi_eta), 1/z)
                s = series(simplify(p_rho), z, 0, n=max_order).removeO()
                lead = radsimp(simplify(powdenest(s.as_leading_term(z), force=True)))
                power = lead.as_powers_dict().get(z, None)
                coeff = lead / z**power if power is not None else 0
                print("lead =", lead)
                print("power =", power)
                print("coeff =", coeff)
                order = validate_order(power, coeff, [x, y], tol)
                if order is not None:
                    return -order
            except Exception as e:
                print(f"⚠️ fallback z (2D) failed: {e}")
            return None

        else:
            raise NotImplementedError("Only 1D and 2D supported.")

Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.

This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η) as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate, which is essential for understanding the regularity and mapping properties of the corresponding operator.

The function uses symbolic preprocessing to ensure proper factorization of frequency variables, especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).

Parameters

max_order : int, optional
    Maximum number of terms to consider in the series expansion. Default is 10.
tol : float, optional
    Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small, the detected order may be discarded. Default is 1e-3.

Returns

float or None
  • If the symbol is homogeneous, returns its exact homogeneity degree as a float.
  • Otherwise, estimates the dominant asymptotic order from leading terms in the expansion.
  • Returns None if no valid order could be determined.

Notes

  • In 1D, two strategies are used:

    1. Expand directly in xi at infinity.
    2. Substitute xi = 1/z and expand around z = 0.

  • In 2D:

    • Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
    • Expand in rho at infinity, then extract the leading term's power.
    • An alternative substitution using 1/z is also tried if the first method fails.
  • Preprocessing steps:

    • Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
    • Power expressions are factored explicitly to ensure correct symbolic scaling.
  • If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.

  • For non-homogeneous symbols, only the principal asymptotic term is considered.

Raises

NotImplementedError If the spatial dimension is neither 1 nor 2.
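As an independent cross-check on the series-based estimate, the order of a 1D symbol with power-law behaviour p ~ c ξ^m (c ≠ 0) can also be obtained from the logarithmic derivative, m = lim ξ p′(ξ)/p(ξ) as ξ → +∞. This is a different technique from the one `symbol_order` uses; `log_derivative_order` is a hypothetical helper written for illustration.

```python
import sympy as sp


def log_derivative_order(p, xi):
    """Estimate m such that p(xi) ~ c * xi**m (c != 0) as xi -> +oo.

    Uses m = lim_{xi -> oo} xi * p'(xi) / p(xi); returns None when the
    limit is not a finite number (e.g. exponential growth).
    """
    m = sp.limit(sp.simplify(xi * sp.diff(p, xi) / p), xi, sp.oo)
    return m if m.is_finite else None


if __name__ == "__main__":
    xi = sp.Symbol('xi', positive=True)
    for p in (sp.sqrt(1 + xi**2), 1 / (1 + xi**2), xi**3 + 2 * xi, sp.exp(xi)):
        print(p, "->", log_derivative_order(p, xi))
```

Exponentially growing symbols have no finite order, which the helper reports as None rather than a number.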

def asymptotic_expansion(self, order=3):
    def asymptotic_expansion(self, order=3):
        """
        Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).

        This method expands the pseudo-differential symbol in inverse powers of the
        frequency variable(s), either in 1D or 2D. It handles both polynomial and
        exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.

        The expansion is performed directly in Cartesian coordinates for 1D symbols.
        For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion
        at infinity in ρ, then converts the result back to Cartesian coordinates.

        Parameters
        ----------
        order : int, optional
            Maximum order of the asymptotic expansion. Default is 3.

        Returns
        -------
        sympy.Expr
            The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates.
            If expansion fails, returns the original unexpanded symbol.

        Notes:
        - In 1D: expansion is performed directly in terms of ξ.
        - In 2D: the symbol is first rewritten in polar coordinates (ρ,θ), expanded asymptotically
          in ρ → ∞, then converted back to Cartesian coordinates (ξ,η).
        - Handles special case when the symbol is an exponential function by expanding its argument.
        - Symbolic normalization is applied early (via `simplify`) for 2D expressions to improve convergence.
        - Robust to failures: catches exceptions and issues warnings instead of raising errors.
        - Final expression is simplified using `powdenest` and `expand` for improved readability.
        """
        p = self.symbol

        if self.dim == 1:
            xi = symbols('xi', real=True, positive=True)

            try:
                # Case: exponential function
                if p.func == exp and len(p.args) == 1:
                    arg = p.args[0]
                    arg_series = series(arg, xi, oo, n=order).removeO()
                    expanded = series(exp(expand(arg_series)), xi, oo, n=order).removeO()
                    return simplify(powdenest(expanded, force=True))
                else:
                    expanded = series(p, xi, oo, n=order).removeO()
                    return simplify(powdenest(expanded, force=True))

            except Exception as e:
                print(f"Warning: 1D expansion failed: {e}")
                return p

        elif self.dim == 2:
            xi, eta = symbols('xi eta', real=True, positive=True)
            rho, theta = symbols('rho theta', real=True, positive=True)

            # Normalize before substitution
            p = simplify(p)

            # Substitute polar coordinates
            p_polar = p.subs({
                xi: rho * cos(theta),
                eta: rho * sin(theta)
            })

            try:
                # Handle exponentials
                if p_polar.func == exp and len(p_polar.args) == 1:
                    arg = p_polar.args[0]
                    arg_series = series(arg, rho, oo, n=order).removeO()
                    expanded = series(exp(expand(arg_series)), rho, oo, n=order).removeO()
                else:
                    expanded = series(p_polar, rho, oo, n=order).removeO()

                # Convert back to Cartesian
                norm = sqrt(xi**2 + eta**2)
                expansion_cart = expanded.subs({
                    rho: norm,
                    cos(theta): xi / norm,
                    sin(theta): eta / norm
                })

                # Final simplifications
                result = simplify(powdenest(expansion_cart, force=True))
                result = expand(result)
                return result

            except Exception as e:
                print(f"Warning: 2D expansion failed: {e}")
                return p

Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).

This method expands the pseudo-differential symbol in inverse powers of the frequency variable(s), either in 1D or 2D. It handles both polynomial and exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.

The expansion is performed directly in Cartesian coordinates for 1D symbols. For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion at infinity in ρ, then converts the result back to Cartesian coordinates.

Parameters

order : int, optional Maximum order of the asymptotic expansion. Default is 3.

Returns

sympy.Expr The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates. If expansion fails, returns the original unexpanded symbol.

Notes:

  • In 1D: expansion is performed directly in terms of ξ.
  • In 2D: the symbol is first rewritten in polar coordinates (ρ,θ), expanded asymptotically in ρ → ∞, then converted back to Cartesian coordinates (ξ,η).
  • Handles special case when the symbol is an exponential function by expanding its argument.
  • Symbolic normalization is applied early (via simplify) for 2D expressions to improve convergence.
  • Robust to failures: catches exceptions and issues warnings instead of raising errors.
  • Final expression is simplified using powdenest and expand for improved readability.
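As a worked illustration, take the symbol ⟨ξ⟩ = √(1 + ξ²) (an illustrative choice, not one taken from this page). Factoring out the leading power of the frequency and expanding in inverse powers gives the terms the method is meant to produce; in 2D the same computation in ρ, followed by ρ = √(ξ² + η²), yields the Cartesian form. The leading term is the principal symbol.

```latex
\sqrt{1+\xi^{2}} = \xi\,\bigl(1+\xi^{-2}\bigr)^{1/2}
  = \xi + \frac{1}{2\xi} - \frac{1}{8\xi^{3}} + O\!\left(\xi^{-5}\right),
  \qquad \xi \to +\infty,

\sqrt{1+\xi^{2}+\eta^{2}} = \rho + \frac{1}{2\rho} - \frac{1}{8\rho^{3}} + O\!\left(\rho^{-5}\right),
  \qquad \rho = \sqrt{\xi^{2}+\eta^{2}} \to \infty .
```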
def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
    def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
        """
        Compose two pseudo-differential operators using an asymptotic expansion
        in the chosen quantization scheme (Kohn–Nirenberg or Weyl).

        Parameters
        ----------
        other : PseudoDifferentialOperator
            The operator to compose with this one.
        order : int, default=1
            Maximum order of the asymptotic expansion.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization mode:
            - 'kn' : Kohn–Nirenberg quantization (left-quantized)
            - 'weyl' : Weyl symmetric quantization
        sign_convention : {'standard', 'inverse'}, optional
            Controls the phase factor convention for the KN case:
            - 'standard' → (i)^(-n), gives [x, ξ] = +i (physics convention)
            - 'inverse' → (i)^(+n), gives [x, ξ] = -i (mathematical adjoint convention)
            If None, defaults to 'standard'.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the composed symbol up to the given order.

        Raises
        ------
        ValueError
            If mode or sign_convention is not one of the accepted values.

        Notes
        -----
        - In 1D (Kohn–Nirenberg):
            (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (i·sgn)ⁿ ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ),
            with sgn = -1 for 'standard' and +1 for 'inverse'.
        - In 1D (Weyl):
            (p # q)(x, ξ) = exp[(i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ)
            truncated at given order.
        - In 2D, the same expansions are taken over multi-indices in (x, y) and (ξ, η).

        Examples
        --------
        X = a*x, Y = b*ξ
        X_op.compose_asymptotic(Y_op, order=3, mode='weyl')
        """

        from sympy import I, diff, factorial, simplify, symbols

        assert self.dim == other.dim, "Operator dimensions must match"
        p, q = self.symbol, other.symbol

        # Default sign convention
        if sign_convention is None:
            sign_convention = 'standard'
        if sign_convention not in ('standard', 'inverse'):
            raise ValueError("sign_convention must be either 'standard' or 'inverse'")
        sign = -1 if sign_convention == 'standard' else +1

        # --- 1D case ---
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            result = 0

            if mode == 'kn':  # Kohn–Nirenberg
                for n in range(order + 1):
                    # Exact phase factor i**(sign*n): SymPy's I instead of the float 1j
                    term = (1 / factorial(n)) * diff(p, xi, n) * diff(q, x, n) * I ** (sign * n)
                    result += term

            elif mode == 'weyl':  # Weyl symmetric composition
                # Weyl star product: exp((i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q))
                for n in range(order + 1):
                    for k in range(n + 1):
                        # (∂_ξ^k ∂_x^(n−k) p)(∂_x^k ∂_ξ^(n−k) q)
                        coeff = (1 / (factorial(k) * factorial(n - k))) * ((I / 2) ** n) * ((-1) ** (n - k))
                        term = coeff * diff(p, xi, k, x, n - k) * diff(q, x, k, xi, n - k)
                        result += term

            else:
                raise ValueError("mode must be either 'kn' or 'weyl'")

            return simplify(result)

        # --- 2D case ---
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            result = 0

            if mode == 'kn':
                for n in range(order + 1):
                    for i in range(n + 1):
                        j = n - i
                        term = (1 / (factorial(i) * factorial(j))) * \
                               diff(p, xi, i, eta, j) * diff(q, x, i, y, j) * I ** (sign * n)
                        result += term

            elif mode == 'weyl':
                # Moyal product: expand exp((i/2) Σ_k (∂_ξk^p ∂_xk^q − ∂_xk^p ∂_ξk^q))
                # over multi-indices (a, b, c, d) with a + b + c + d = n
                for n in range(order + 1):
                    for a in range(n + 1):
                        for b in range(n + 1 - a):
                            for c in range(n + 1 - a - b):
                                d = n - a - b - c
                                coeff = ((I / 2) ** n) * ((-1) ** (c + d)) / (
                                    factorial(a) * factorial(b) * factorial(c) * factorial(d))
                                term = coeff * diff(p, xi, a, eta, b, x, c, y, d) \
                                    * diff(q, x, a, y, b, xi, c, eta, d)
                                result += term
            else:
                raise ValueError("mode must be either 'kn' or 'weyl'")

            return simplify(result)

        else:
            raise NotImplementedError("Only 1D and 2D cases are implemented")

Compose two pseudo-differential operators using an asymptotic expansion in the chosen quantization scheme (Kohn–Nirenberg or Weyl).

Parameters

other : PseudoDifferentialOperator
    The operator to compose with this one.
order : int, default=1
    Maximum order of the asymptotic expansion.
mode : {'kn', 'weyl'}, default='kn'
    Quantization mode:
      • 'kn' : Kohn–Nirenberg quantization (left-quantized)
      • 'weyl' : Weyl symmetric quantization
sign_convention : {'standard', 'inverse'}, optional
    Controls the phase factor convention for the KN case:
      • 'standard' → (i)^(-n), gives [x, ξ] = +i (physics convention)
      • 'inverse' → (i)^(+n), gives [x, ξ] = -i (mathematical adjoint convention)
    If None, defaults to 'standard'.

Returns

sympy.Expr Symbolic expression for the composed symbol up to the given order.

Notes

  • In 1D (Kohn–Nirenberg): (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (i·sgn)ⁿ ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ), with sgn = −1 for 'standard' and +1 for 'inverse'.
  • In 1D (Weyl): (p # q)(x, ξ) = exp[(i/2)(∂_ξ^p ∂_x^q − ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ), truncated at the given order.

Examples

X = a*x, Y = b*ξ
X_op.compose_asymptotic(Y_op, order=3, mode='weyl')
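The 1D Kohn–Nirenberg formula can be checked on plain SymPy expressions. The helper `kn_compose` below is a hypothetical stand-alone sketch (not a method of `PseudoDifferentialOperator`) that follows the series in the Notes, with `sign=-1` mirroring the 'standard' convention; it reproduces [x, ξ] = +i and the symbol of D∘x² with D = −i∂ₓ.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_compose(p, q, order, sign=-1):
    """Truncated 1D Kohn-Nirenberg composition of symbols p(x, xi) and q(x, xi).

    Sum over n = 0..order of (1/n!) * I**(sign*n) * (d/dxi)^n p * (d/dx)^n q;
    sign=-1 corresponds to 'standard', sign=+1 to 'inverse'.
    """
    total = sp.Integer(0)
    for n in range(order + 1):
        total += sp.I**(sign * n) / sp.factorial(n) * sp.diff(p, xi, n) * sp.diff(q, x, n)
    return sp.expand(total)


if __name__ == "__main__":
    # Commutator symbol compose(x, xi) - compose(xi, x)
    print(sp.simplify(kn_compose(x, xi, 1) - kn_compose(xi, x, 1)))
    # Symbol of D(x**2 u) with D = -i d/dx
    print(kn_compose(xi, x**2, 2))
```

The second case gives x²ξ − 2ix, matching D(x²u) = x²Du − 2ixu; switching to `sign=+1` flips the commutator to −i, as the 'inverse' convention states.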

def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
    def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
        """
        Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators
        using formal asymptotic expansion of their composition symbols.

        This method computes the asymptotic expansion of the commutator's symbol up to a given
        order, based on the symbolic calculus of pseudo-differential operators in the
        Kohn–Nirenberg (default) or Weyl quantization. The result is a purely symbolic SymPy
        expression that captures the leading-order noncommutativity of the operators.

        Parameters
        ----------
        other : PseudoDifferentialOperator
            The pseudo-differential operator B to commute with this operator A.
        order : int, default=1
            Maximum order of the asymptotic expansion.
            - order=1 yields the leading term proportional to the Poisson bracket {p, q}.
            - Higher orders include correction terms involving higher mixed derivatives.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization mode, passed to `compose_asymptotic`.
        sign_convention : {'standard', 'inverse'}, optional
            Phase convention, passed to `compose_asymptotic`.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the asymptotic expansion of the commutator symbol
            σ([A,B]) = σ(A∘B − B∘A).
        """
        from sympy import simplify

        assert self.dim == other.dim, "Operator dimensions must match"

        pq = self.compose_asymptotic(other, order=order, mode=mode, sign_convention=sign_convention)
        qp = other.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)

        comm_symbol = simplify(pq - qp)

        return comm_symbol

Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators using formal asymptotic expansion of their composition symbols.

This method computes the asymptotic expansion of the commutator's symbol up to a given order, based on the symbolic calculus of pseudo-differential operators in the chosen quantization (Kohn–Nirenberg by default, or Weyl). The result is a purely symbolic sympy expression that captures the leading-order noncommutativity of the operators.

Parameters

other : PseudoDifferentialOperator
    The pseudo-differential operator B to commute with this operator A.
order : int, default=1
    Maximum order of the asymptotic expansion.
    - order=1 yields the leading term, proportional to the Poisson bracket {p, q}.
    - Higher orders include correction terms involving higher mixed derivatives.
mode : {'kn', 'weyl'}, default='kn'
    Quantization mode, passed to compose_asymptotic.
sign_convention : {'standard', 'inverse'}, optional
    Phase-factor convention, passed to compose_asymptotic.

Returns

sympy.Expr
    Symbolic expression for the asymptotic expansion of the commutator symbol σ([A,B]) = σ(A∘B − B∘A).
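In the 'standard' Kohn–Nirenberg convention, the order-1 commutator symbol equals −i{p, q}, where {p, q} = ∂_ξ p ∂_x q − ∂_x p ∂_ξ q. The standalone sketch below checks this claim for a harmonic-oscillator symbol. The helpers kn_compose, kn_commutator and poisson_bracket are hypothetical re-implementations, not psipy functions.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_compose(p, q, order=1):
    # Hypothetical helper: truncated 1D Kohn-Nirenberg composition, (-i)^n phase factor
    return sum((-sp.I)**n / sp.factorial(n) * sp.diff(p, xi, n) * sp.diff(q, x, n)
               for n in range(order + 1))


def kn_commutator(p, q, order=1):
    # Symbol of [A, B] = A o B - B o A, truncated at the given order
    return sp.expand(kn_compose(p, q, order) - kn_compose(q, p, order))


def poisson_bracket(p, q):
    # {p, q} = dp/dxi * dq/dx - dp/dx * dq/dxi
    return sp.diff(p, xi) * sp.diff(q, x) - sp.diff(p, x) * sp.diff(q, xi)


p = xi**2 + x**2      # harmonic-oscillator symbol
q = x * xi            # symbol of x*D
comm = kn_commutator(p, q, order=1)
print(comm)
print(sp.simplify(comm + sp.I * poisson_bracket(p, q)))  # 0
```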

def right_inverse_asymptotic(self, order=1):
    def right_inverse_asymptotic(self, order=1):
        """
        Construct a formal right inverse R of the pseudo-differential operator P such that
        the composition P ∘ R equals the identity up to a remainder of order -order.

        This method computes an asymptotic expansion for the right inverse using recursive
        corrections based on derivatives of the symbol p(x, ξ) and lower-order terms of R.

        Parameters
        ----------
        order : int
            Number of correction steps in the asymptotic expansion. Higher values improve
            the approximation at the cost of larger expressions and longer computation.

        Returns
        -------
        sympy.Expr
            The symbolic expression representing the formal right inverse R(x, ξ), which satisfies:
            P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.

        Notes
        -----
        - In 1D: The recursion involves spatial derivatives of R and derivatives of p with respect to ξ.
        - In 2D: The multi-index generalization is used with mixed derivatives in ξ and η.
        - The construction starts from r = 1/p, so it relies on the non-vanishing of p.
        - Each step applies the fixed-point update R = r (1 − T(R)), where
          T(R) = Σ_{k≥1} (1/k!) (−i)^k ∂_ξ^k p ∂_x^k R collects the derivative terms
          of the Kohn–Nirenberg composition P ∘ R.
        """
        p = self.symbol
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            r = 1 / p  # r0: leading-order inverse
            R = r
            for n in range(1, order + 1):
                term = 0
                for k in range(1, n + 1):
                    coeff = (1j)**(-k) / factorial(k)  # (1j)**(-k) == (-i)**k
                    inner = diff(p, xi, k) * diff(R, x, k)
                    term += coeff * inner
                # Restart from r so that earlier corrections are not counted twice
                R = r - r * term
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            r = 1 / p
            R = r
            for n in range(1, order + 1):
                term = 0
                for k1 in range(n + 1):
                    for k2 in range(n + 1 - k1):
                        if k1 + k2 == 0:
                            continue
                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
                        dp = diff(p, xi, k1, eta, k2)
                        dR = diff(R, x, k1, y, k2)
                        term += coeff * dp * dR
                R = r - r * term
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
        return R

Construct a formal right inverse R of the pseudo-differential operator P such that the composition P ∘ R equals the identity up to a remainder of order -order.

This method computes an asymptotic expansion for the right inverse using recursive corrections based on derivatives of the symbol p(x, ξ) and lower-order terms of R.

Parameters

order : int
    Number of correction steps in the asymptotic expansion. Higher values improve the approximation at the cost of larger expressions and longer computation.

Returns

sympy.Expr
    The symbolic expression representing the formal right inverse R(x, ξ), which satisfies P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.

Notes

  • In 1D: The recursion involves spatial derivatives of R and derivatives of p with respect to ξ.
  • In 2D: The multi-index generalization is used with mixed derivatives in ξ and η.
  • The construction starts from r = 1/p, so it relies on the non-vanishing of p.
  • Each step applies the fixed-point update R = r (1 − T(R)), where T(R) = Σ_{k≥1} (1/k!) (−i)^k ∂_ξ^k p ∂_x^k R collects the derivative terms of the Kohn–Nirenberg composition P ∘ R.
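The fixed-point update can be checked on a symbol whose composition is exact. The standalone sketch below uses a hypothetical helper, right_inverse_1d, that mirrors the 1D recursion with exact SymPy arithmetic (sp.I instead of 1j). It applies the recursion to p = ξ + x, the symbol of D + x. Because p is linear in ξ, the composition P ∘ R = pR − i ∂_x R holds exactly. By hand, the residual should shrink from 3/(x+ξ)⁴ at order 1 to −15i/(x+ξ)⁶ at order 2.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def right_inverse_1d(p, order=1):
    """Hypothetical standalone sketch of the KN right-inverse recursion in 1D.

    Iterates R <- r * (1 - T(R)) with r = 1/p and
    T(R) = sum_{k=1}^{n} (1/k!) (-i)^k d_xi^k p * d_x^k R.
    """
    r = 1 / p
    R = r
    for n in range(1, order + 1):
        T = sum((-sp.I)**k / sp.factorial(k) * sp.diff(p, xi, k) * sp.diff(R, x, k)
                for k in range(1, n + 1))
        R = r - r * T          # restart from r so earlier corrections are not double-counted
    return R


p = xi + x                     # symbol of D + x (linear in xi)
for order in (1, 2):
    R = right_inverse_1d(p, order)
    # p is linear in xi, so P o R = p*R - i*dR/dx exactly
    residual = sp.simplify(p * R - sp.I * sp.diff(R, x) - 1)
    print(order, residual)
```

If R is instead updated as R − r·T(R), the first correction is added twice at order 2, and the residual grows back to order (x+ξ)⁻².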
def left_inverse_asymptotic(self, order=1):
    def left_inverse_asymptotic(self, order=1):
        """
        Construct a formal left inverse L such that the composition L ∘ P equals the identity
        operator up to terms of order ξ^{-order}. This expansion is performed asymptotically
        at infinity in the frequency variable(s).

        The left inverse is built iteratively using symbolic differentiation and the
        method of asymptotic expansions for pseudo-differential operators. It ensures that:

            L(x, D) ∘ P(x, D) = Id + (remainder of order -order)

        Parameters
        ----------
        order : int, optional
            Maximum number of correction steps in the asymptotic expansion (default is 1).
            Higher values yield more accurate inverses at the cost of increased computational complexity.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the full (truncated) symbol L(x, ξ) of the formal left
            inverse. It depends on the spatial and frequency variables and includes
            correction terms up to the specified order.

        Notes
        -----
        - In 1D: Uses recursive application of the Leibniz (composition) formula for symbols.
        - In 2D: Generalizes to multi-indices for mixed derivatives in (x, y) and (ξ, η).
        - Each step applies the fixed-point update L = (1 − T(L)) l, with l = 1/p and
          T(L) = Σ_{k≥1} (1/k!) (−i)^k ∂_ξ^k L ∂_x^k p, so every term combines derivatives
          of the original symbol p(x, ξ) with the previously computed approximation of L.
        - Coefficients include powers of 1j (i) and factorial normalization for derivative terms.
        """
        p = self.symbol
        if self.dim == 1:
            x = self.vars_x[0]
            xi = symbols('xi', real=True)
            l = 1 / p  # l0: leading-order inverse
            L = l
            for n in range(1, order + 1):
                term = 0
                for k in range(1, n + 1):
                    coeff = (1j)**(-k) / factorial(k)  # (1j)**(-k) == (-i)**k
                    inner = diff(L, xi, k) * diff(p, x, k)
                    term += coeff * inner
                # Restart from l so that earlier corrections are not counted twice
                L = l - term * l
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            l = 1 / p
            L = l
            for n in range(1, order + 1):
                term = 0
                for k1 in range(n + 1):
                    for k2 in range(n + 1 - k1):
                        if k1 + k2 == 0:
                            continue
                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
                        dp = diff(p, x, k1, y, k2)
                        dL = diff(L, xi, k1, eta, k2)
                        term += coeff * dL * dp
                L = l - term * l
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")
        return L

Construct a formal left inverse L such that the composition L ∘ P equals the identity operator up to terms of order ξ^{-order}. This expansion is performed asymptotically at infinity in the frequency variable(s).

The left inverse is built iteratively using symbolic differentiation and the method of asymptotic expansions for pseudo-differential operators. It ensures that:

L(x, D) ∘ P(x, D) = Id + (remainder of order -order)

Parameters

order : int, optional
    Maximum number of correction steps in the asymptotic expansion (default is 1). Higher values yield more accurate inverses at the cost of increased computational complexity.

Returns

sympy.Expr
    Symbolic expression for the full (truncated) symbol L(x, ξ) of the formal left inverse. It depends on the spatial and frequency variables and includes correction terms up to the specified order.

Notes

  • In 1D: Uses recursive application of the Leibniz (composition) formula for symbols.
  • In 2D: Generalizes to multi-indices for mixed derivatives in (x, y) and (ξ, η).
  • Each step applies the fixed-point update L = (1 − T(L)) l, with l = 1/p and T(L) = Σ_{k≥1} (1/k!) (−i)^k ∂_ξ^k L ∂_x^k p, so every term combines derivatives of the original symbol p(x, ξ) with the previously computed approximation of L.
  • Coefficients include powers of 1j (i) and factorial normalization for derivative terms.
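The left-inverse update can be tested in the same way. The standalone sketch below uses a hypothetical helper, left_inverse_1d, that mirrors the 1D recursion with exact SymPy arithmetic. It applies the recursion to p = ξ + x², where ∂_x²p ≠ 0, so the k = 2 term of the composition contributes. A second hypothetical helper, kn_compose_exact, computes L ∘ P exactly, since ∂_x^k p = 0 for k ≥ 3. The sketch evaluates the residual at one sample point. By hand, the order-1 residual is −2s³ + 12x²s⁴ + 24ixs⁵ with s = 1/(ξ + x²), and order 2 removes the s³ term.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def left_inverse_1d(p, order=1):
    """Hypothetical standalone sketch of the KN left-inverse recursion in 1D.

    Iterates L <- (1 - T(L)) * l with l = 1/p and
    T(L) = sum_{k=1}^{n} (1/k!) (-i)^k d_xi^k L * d_x^k p.
    """
    l = 1 / p
    L = l
    for n in range(1, order + 1):
        T = sum((-sp.I)**k / sp.factorial(k) * sp.diff(L, xi, k) * sp.diff(p, x, k)
                for k in range(1, n + 1))
        L = l - T * l          # restart from l so earlier corrections are not double-counted
    return L


def kn_compose_exact(a, p, max_k=3):
    # Hypothetical helper: KN composition a o p; exact here since d_x^k p = 0 for k >= 3
    return sum((-sp.I)**k / sp.factorial(k) * sp.diff(a, xi, k) * sp.diff(p, x, k)
               for k in range(max_k))


p = xi + x**2
point = {x: 0.5, xi: 20.0}
for order in (1, 2):
    L = left_inverse_1d(p, order)
    residual = kn_compose_exact(L, p) - 1
    print(order, abs(complex(residual.subs(point))))
```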
def formal_adjoint(self):
    def formal_adjoint(self):
        """
        Compute a leading-order approximation of the formal adjoint symbol P* of the
        pseudo-differential operator.

        The adjoint is defined by ⟨P u, v⟩ = ⟨u, P* v⟩ for all test functions u and v.
        In the Kohn–Nirenberg quantization its symbol has the asymptotic expansion
        p*(x, ξ) ~ Σ_α (1/α!) ∂_ξ^α D_x^α conj(p(x, ξ)), with D_x = -i ∂_x.
        This method keeps only the α = 0 term, the complex conjugate of the symbol,
        and expands it asymptotically at infinity in the frequency variable(s).
        The α ≠ 0 terms vanish when p does not depend on x.

        Returns
        -------
        sympy.Expr
            The adjoint symbol P*(x, ξ) in 1D or P*(x, y, ξ, η) in 2D.

        Notes
        -----
        - In 1D, the series is taken in powers of 1/ξ as ξ → ∞ (SymPy series order n=6).
        - In 2D, the series is taken in the radial variable |ξ| = sqrt(ξ² + η²); SymPy only
          expands where this expression appears explicitly in the symbol.
        - The result is passed through simplify() for readability.
        """
        p = self.symbol
        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)
            p_star = conjugate(p)
            p_star = simplify(series(p_star, xi, oo, n=6).removeO())
            return p_star
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            p_star = conjugate(p)
            p_star = simplify(series(p_star, sqrt(xi**2 + eta**2), oo, n=6).removeO())
            return p_star
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute a leading-order approximation of the formal adjoint symbol P* of the pseudo-differential operator.

The adjoint is defined by ⟨P u, v⟩ = ⟨u, P* v⟩ for all test functions u and v. In the Kohn–Nirenberg quantization its symbol has the asymptotic expansion p*(x, ξ) ~ Σ_α (1/α!) ∂_ξ^α D_x^α conj(p(x, ξ)), with D_x = -i ∂_x. This method keeps only the α = 0 term, the complex conjugate of the symbol, and expands it asymptotically at infinity in the frequency variable(s). The α ≠ 0 terms vanish when p does not depend on x.

Returns

sympy.Expr
    The adjoint symbol P*(x, ξ) in 1D or P*(x, y, ξ, η) in 2D.

Notes:

  • In 1D, the series is taken in powers of 1/ξ as ξ → ∞ (SymPy series order n=6).
  • In 2D, the series is taken in the radial variable |ξ| = sqrt(ξ² + η²); SymPy only expands where this expression appears explicitly in the symbol.
  • The result is passed through simplify() for readability.
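When the symbol depends on x, the correction terms of the adjoint expansion matter. The standalone sketch below computes the truncated 1D Kohn–Nirenberg adjoint with all terms up to a chosen order. The helper kn_adjoint_1d is hypothetical and is not part of psipy. For p = xξ, the symbol of xD, the operator identity (xD)* = Dx = xD − i gives xξ − i. Conjugation alone would return xξ.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)


def kn_adjoint_1d(p, order=2):
    """Hypothetical helper: KN adjoint symbol p*(x, xi) ~ sum_k (1/k!) d_xi^k D_x^k conj(p).

    D_x = -i d/dx; the sum is truncated after k = order.
    """
    pc = sp.conjugate(p)
    return sp.expand(sum(
        (-sp.I)**k / sp.factorial(k) * sp.diff(sp.diff(pc, x, k), xi, k)
        for k in range(order + 1)))


print(kn_adjoint_1d(x * xi))           # symbol of D*x
print(kn_adjoint_1d(x * xi - sp.I / 2))  # symbol of (xD + Dx)/2, which is self-adjoint
```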
def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
    def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
        """
        Compute the symbol of exp(tP) as a truncated power series.

        This method approximates the exponential of a pseudo-differential operator
        by the Taylor polynomial Σ_{n=0}^{order} (t^n/n!) P^n, where each power P^n
        is obtained by asymptotic composition. The result is valid up to the
        specified truncation order.

        Parameters
        ----------
        t : float or sympy.Symbol, default=1.0
            Time or evolution parameter. Common uses:
            - t = -i*τ for Schrödinger evolution: exp(-iτH)
            - t = τ for heat/diffusion: exp(τΔ)
            - t for general propagators
        order : int, default=1
            Truncation order of the power series, also used as the order of each
            compose_asymptotic call. Higher orders include more composition terms,
            improving accuracy for small t or when non-commutativity effects are significant.
        mode : {'kn', 'weyl'}, default='kn'
            Quantization mode used for the compositions.
        sign_convention : {'standard', 'inverse'}, optional
            Phase-factor convention, passed to compose_asymptotic.

        Returns
        -------
        sympy.Expr
            Symbolic expression for the exponential operator symbol, computed
            as a truncated series up to the specified order.

        Notes
        -----
        - When compositions reduce to pointwise products (e.g., pure multiplication
          operators or Fourier multipliers), the exact result is exp(t*p(x,ξ)); this
          method returns its Taylor polynomial of degree `order`.

        - For general non-commutative operators, powers are built by iterated composition:
          exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...

        - Each power P^n is computed via compose_asymptotic, which accounts
          for the non-commutativity through derivative terms.

        - Because the series is truncated, it is accurate only where |t p(x, ξ)| is
          small; for unbounded symbols such as ξ² it breaks down at large |ξ|.

        - In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) is
          the time evolution operator.

        - In parabolic PDEs (heat equation): exp(tΔ) is the heat semigroup, whose
          integral kernel is the heat kernel.
        """
        if self.dim == 1:
            x = self.vars_x[0]

            # Initialize with identity
            result = 1

            # First order term: tP
            current_power = self.symbol
            result += t * current_power

            # Higher order terms: (t^n/n!) P^n computed via composition
            for n in range(2, order + 1):
                # Compute P^n = P^(n-1) ∘ P via asymptotic composition,
                # wrapping P^(n-1) in a temporary operator
                temp_op = PseudoDifferentialOperator(
                    current_power, [x], mode='symbol'
                )
                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)

                # Add term (t^n/n!) * P^n
                coeff = t**n / factorial(n)
                result += coeff * current_power

            return simplify(result)

        elif self.dim == 2:
            x, y = self.vars_x

            # Initialize with identity
            result = 1

            # First order term: tP
            current_power = self.symbol
            result += t * current_power

            # Higher order terms: (t^n/n!) P^n computed via composition
            for n in range(2, order + 1):
                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
                temp_op = PseudoDifferentialOperator(
                    current_power, [x, y], mode='symbol'
                )
                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)

                # Add term (t^n/n!) * P^n
                coeff = t**n / factorial(n)
                result += coeff * current_power

            return simplify(result)

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute the symbol of exp(tP) as a truncated power series.

This method approximates the exponential of a pseudo-differential operator by the Taylor polynomial Σ_{n=0}^{order} (t^n/n!) P^n, where each power P^n is obtained by asymptotic composition. The result is valid up to the specified truncation order.

Parameters

t : float or sympy.Symbol, default=1.0
    Time or evolution parameter. Common uses:
    - t = -i*τ for Schrödinger evolution: exp(-iτH)
    - t = τ for heat/diffusion: exp(τΔ)
    - t for general propagators
order : int, default=1
    Truncation order of the power series, also used as the order of each compose_asymptotic call. Higher orders include more composition terms, improving accuracy for small t or when non-commutativity effects are significant.
mode : {'kn', 'weyl'}, default='kn'
    Quantization mode used for the compositions.
sign_convention : {'standard', 'inverse'}, optional
    Phase-factor convention, passed to compose_asymptotic.

Returns

sympy.Expr
    Symbolic expression for the exponential operator symbol, computed as a truncated series up to the specified order.

Notes

  • When compositions reduce to pointwise products (e.g., pure multiplication operators or Fourier multipliers), the exact result is exp(t*p(x,ξ)); this method returns its Taylor polynomial of degree order.

  • For general non-commutative operators, powers are built by iterated composition: exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...

  • Each power P^n is computed via compose_asymptotic, which accounts for the non-commutativity through derivative terms.

  • Because the series is truncated, it is accurate only where |t p(x, ξ)| is small; for unbounded symbols such as ξ² it breaks down at large |ξ|.

  • In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) is the time evolution operator.

  • In parabolic PDEs (heat equation): exp(tΔ) is the heat semigroup, whose integral kernel is the heat kernel.
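The iterated-composition scheme can be reproduced with a standalone sketch. It assumes 1D Kohn–Nirenberg composition in the 'standard' convention. It also reuses the truncation order for each composition, as the method above does. The helpers kn_compose and exp_symbol_taylor are hypothetical. For p = x + ξ, the symbol of x + D, the second power is (x + ξ)² − i, because xD + Dx = 2xD − i.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)
t = sp.Symbol('t')


def kn_compose(a, b, order):
    # Hypothetical helper: truncated 1D Kohn-Nirenberg composition, (-i)^k phase factor
    return sp.expand(sum((-sp.I)**k / sp.factorial(k) * sp.diff(a, xi, k) * sp.diff(b, x, k)
                         for k in range(order + 1)))


def exp_symbol_taylor(p, t, order):
    """Hypothetical sketch: sum_{n<=order} (t^n/n!) P^n with P^n = P^(n-1) o P."""
    result = 1 + t * p
    power = p
    for n in range(2, order + 1):
        power = kn_compose(power, p, order)   # P^n via asymptotic composition
        result += t**n / sp.factorial(n) * power
    return sp.expand(result)


p = x + xi                     # x and D do not commute
print(exp_symbol_taylor(p, t, order=2))
```

For a Fourier multiplier such as p = ξ, every composition is a plain product, and the sketch returns the Taylor polynomial of exp(tξ).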

def trace_formula(self, volume_element=None, numerical=False, x_bounds=None, xi_bounds=None):
    def trace_formula(self, volume_element=None, numerical=False,
                      x_bounds=None, xi_bounds=None):
        """
        Compute the semiclassical trace of the pseudo-differential operator.

        The trace formula relates the quantum trace of an operator to a
        phase-space integral of its symbol, providing a fundamental link
        between classical and quantum mechanics. This implementation supports
        both symbolic and numerical integration.

        Parameters
        ----------
        volume_element : sympy.Expr, optional
            Custom volume element for the phase space integration. If None,
            uses the standard Liouville measure dx dξ/(2π)^d.
        numerical : bool, default=False
            If True, perform numerical integration over the specified bounds
            (the symbol must be real-valued). If False, attempt symbolic
            integration (may fail for complicated symbols).
        x_bounds : tuple of tuples, optional
            Spatial integration bounds. For 1D: ((x_min, x_max),)
            For 2D: ((x_min, x_max), (y_min, y_max))
            Required if numerical=True.
        xi_bounds : tuple of tuples, optional
            Frequency integration bounds. For 1D: ((xi_min, xi_max),)
            For 2D: ((xi_min, xi_max), (eta_min, eta_max))
            Required if numerical=True.

        Returns
        -------
        sympy.Expr or float
            The trace of the operator. Returns a symbolic expression if
            numerical=False, or a float if numerical=True.

        Notes
        -----
        - The semiclassical trace formula states:
          Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ
          where d is the spatial dimension and p(x,ξ) is the operator symbol.

        - For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ

        - For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη

        - This formula is exact for trace-class operators with integrable symbols
          and provides an asymptotic approximation for general pseudo-differential operators.

        - Physical interpretation: the trace counts the "number of states"
          weighted by the observable p(x,ξ).

        - For an operator that approximately projects onto a phase-space region Ω
          (symbol χ_Ω with χ² = χ), the trace approximates the dimension of its range
          by vol(Ω)/(2π)^d.

        - The factor (2π)^{-d} comes from the normalization of the Fourier transform
          in the quantization formula (units with ℏ = 1).
        """
        from sympy import integrate, simplify, lambdify
        from scipy.integrate import dblquad, nquad

        p = self.symbol

        if numerical:
            if x_bounds is None or xi_bounds is None:
                raise ValueError(
                    "x_bounds and xi_bounds must be provided for numerical integration"
                )

        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)

            if volume_element is None:
                volume_element = 1 / (2 * pi)

            if numerical:
                # Numerical integration; dblquad calls func(inner, outer) = func(xi, x)
                p_func = lambdify((x, xi), p, 'numpy')
                (x_min, x_max), = x_bounds
                (xi_min, xi_max), = xi_bounds

                def integrand(xi_val, x_val):
                    return p_func(x_val, xi_val)

                result, error = dblquad(
                    integrand,
                    x_min, x_max,
                    lambda _: xi_min, lambda _: xi_max
                )

                # Scale both the value and its error estimate by the volume element
                result *= float(volume_element)
                error *= float(volume_element)
                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
                return result

            else:
                # Symbolic integration
                integrand = p * volume_element

                try:
                    # Try to integrate over xi first, then x
                    integral_xi = integrate(integrand, (xi, -oo, oo))
                    integral_x = integrate(integral_xi, (x, -oo, oo))
                    return simplify(integral_x)
                except Exception:
                    print("Warning: Symbolic integration failed. Try numerical=True")
                    return integrate(integrand, (xi, -oo, oo), (x, -oo, oo))

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)

            if volume_element is None:
                volume_element = 1 / (4 * pi**2)

            if numerical:
                # Numerical integration in 4D; nquad's first range is the innermost variable
                p_func = lambdify((x, y, xi, eta), p, 'numpy')
                (x_min, x_max), (y_min, y_max) = x_bounds
                (xi_min, xi_max), (eta_min, eta_max) = xi_bounds

                def integrand(eta_val, xi_val, y_val, x_val):
                    return p_func(x_val, y_val, xi_val, eta_val)

                result, error = nquad(
                    integrand,
                    [
                        [eta_min, eta_max],
                        [xi_min, xi_max],
                        [y_min, y_max],
                        [x_min, x_max]
                    ]
                )

                result *= float(volume_element)
                error *= float(volume_element)
                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
                return result

            else:
                # Symbolic integration
                integrand = p * volume_element

                try:
                    # Integrate in order: eta, xi, y, x
                    integral_eta = integrate(integrand, (eta, -oo, oo))
                    integral_xi = integrate(integral_eta, (xi, -oo, oo))
                    integral_y = integrate(integral_xi, (y, -oo, oo))
                    integral_x = integrate(integral_y, (x, -oo, oo))
                    return simplify(integral_x)
                except Exception:
                    print("Warning: Symbolic integration failed. Try numerical=True")
                    return integrate(
                        integrand,
                        (eta, -oo, oo), (xi, -oo, oo),
                        (y, -oo, oo), (x, -oo, oo)
                    )

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute the semiclassical trace of the pseudo-differential operator.

The trace formula relates the quantum trace of an operator to a phase-space integral of its symbol, providing a fundamental link between classical and quantum mechanics. This implementation supports both symbolic and numerical integration.

Parameters

volume_element : sympy.Expr, optional
    Custom volume element for the phase-space integration. If None, uses the standard Liouville measure dx dξ/(2π)^d.
numerical : bool, default=False
    If True, perform numerical integration over the specified bounds (the symbol must be real-valued). If False, attempt symbolic integration (may fail for complicated symbols).
x_bounds : tuple of tuples, optional
    Spatial integration bounds. For 1D: ((x_min, x_max),). For 2D: ((x_min, x_max), (y_min, y_max)). Required if numerical=True.
xi_bounds : tuple of tuples, optional
    Frequency integration bounds. For 1D: ((xi_min, xi_max),). For 2D: ((xi_min, xi_max), (eta_min, eta_max)). Required if numerical=True.

Returns

sympy.Expr or float
    The trace of the operator: a symbolic expression if numerical=False, or a float if numerical=True.

Notes

  • The semiclassical trace formula states: Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ, where d is the spatial dimension and p(x,ξ) is the operator symbol.

  • For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ

  • For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη

  • This formula is exact for trace-class operators with integrable symbols and provides an asymptotic approximation for general pseudo-differential operators.

  • Physical interpretation: the trace counts the "number of states" weighted by the observable p(x,ξ).

  • For an operator that approximately projects onto a phase-space region Ω (symbol χ_Ω with χ² = χ), the trace approximates the dimension of its range by vol(Ω)/(2π)^d.

  • The factor (2π)^{-d} comes from the normalization of the Fourier transform in the quantization formula (units with ℏ = 1).
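A Gaussian symbol makes a convenient test case for the 1D formula. The exact value is (1/2π)∫∫ e^{-(x²+ξ²)} dx dξ = π/(2π) = 1/2. The standalone sketch below follows the same two routes as trace_formula, symbolic integrate and numerical dblquad, but without the class. The truncation box |x|, |ξ| ≤ 8 is an assumption that works here because the integrand is negligible outside it.

```python
import numpy as np
import sympy as sp
from scipy.integrate import dblquad

x, xi = sp.symbols('x xi', real=True)
p = sp.exp(-(x**2 + xi**2))          # Gaussian phase-space symbol

# Symbolic: (1/2π) ∫∫ p dx dξ over the whole phase plane
trace_sym = sp.integrate(p / (2 * sp.pi), (xi, -sp.oo, sp.oo), (x, -sp.oo, sp.oo))

# Numerical: same integral on a truncated box; dblquad calls func(inner, outer)
p_func = sp.lambdify((x, xi), p, 'numpy')
value, err = dblquad(lambda xi_val, x_val: p_func(x_val, xi_val),
                     -8.0, 8.0, lambda _: -8.0, lambda _: 8.0)
trace_num = value / (2 * np.pi)

print(trace_sym, trace_num)
```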

def symplectic_flow(self):
    def symplectic_flow(self):
        """
        Compute the Hamiltonian vector field generated by the symbol.

        This method derives the canonical equations of motion for the phase space variables
        (x, ξ) in 1D or (x, y, ξ, η) in 2D, based on the Hamiltonian formalism. These describe
        how position and frequency variables evolve under the flow generated by the symbol.

        Returns
        -------
        dict
            A dictionary containing the components of the Hamiltonian vector field:
            - In 1D: keys are 'dx/dt' and 'dxi/dt', corresponding to dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x.
            - In 2D: keys are 'dx/dt', 'dy/dt', 'dxi/dt', and 'deta/dt', with similar definitions:
              dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.

        Notes
        -----
        - The Hamiltonian is the full symbol stored in self.symbol; for the classical
          bicharacteristic flow, use an operator whose symbol is the principal symbol.
        - This flow preserves the symplectic structure of phase space and conserves p
          along trajectories.
        """
        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)
            return {
                'dx/dt': diff(self.symbol, xi),
                'dxi/dt': -diff(self.symbol, x)
            }
        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)
            return {
                'dx/dt': diff(self.symbol, xi),
                'dy/dt': diff(self.symbol, eta),
                'dxi/dt': -diff(self.symbol, x),
                'deta/dt': -diff(self.symbol, y)
            }
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute the Hamiltonian vector field generated by the symbol.

This method derives the canonical equations of motion for the phase space variables (x, ξ) in 1D or (x, y, ξ, η) in 2D, based on the Hamiltonian formalism. These describe how position and frequency variables evolve under the flow generated by the symbol.

Returns

dict
    A dictionary containing the components of the Hamiltonian vector field:
    - In 1D: keys are 'dx/dt' and 'dxi/dt', corresponding to dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x.
    - In 2D: keys are 'dx/dt', 'dy/dt', 'dxi/dt', and 'deta/dt', with dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.

Notes

  • The Hamiltonian is the full symbol stored in self.symbol; for the classical bicharacteristic flow, use an operator whose symbol is the principal symbol.
  • This flow preserves the symplectic structure of phase space and conserves p along trajectories.
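The returned expressions are symbolic, so they must be turned into a numerical right-hand side before integration. The standalone sketch below builds the 1D vector field directly with SymPy, rather than calling symplectic_flow, and integrates it with SciPy's solve_ivp. It uses the harmonic-oscillator symbol p = ξ² + x² as an example. The exact trajectory from (1, 0) is x = cos 2t, ξ = −sin 2t, so the orbit returns to its start at t = π, and p stays equal to 1.

```python
import numpy as np
import sympy as sp
from scipy.integrate import solve_ivp

x, xi = sp.symbols('x xi', real=True)
p = xi**2 + x**2                                   # harmonic-oscillator symbol

# Same 1D vector field as symplectic_flow: dx/dt = dp/dxi, dxi/dt = -dp/dx
rhs_sym = [sp.diff(p, xi), -sp.diff(p, x)]
rhs = sp.lambdify((x, xi), rhs_sym, 'numpy')
energy = sp.lambdify((x, xi), p, 'numpy')


def vector_field(t, state):
    # solve_ivp expects f(t, y); state = [x, xi]
    return rhs(state[0], state[1])


sol = solve_ivp(vector_field, (0.0, 10.0), [1.0, 0.0], rtol=1e-10, atol=1e-12,
                t_eval=[0.0, np.pi, 10.0])
drift = np.max(np.abs(energy(sol.y[0], sol.y[1]) - energy(1.0, 0.0)))
print(sol.y[:, 1], drift)
```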
def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-08):
    def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-8):
        """
        Check if the pseudo-differential symbol p(x, ξ) is elliptic over a given grid.

        Ellipticity of order m means |p(x, ξ)| ≥ c (1 + |ξ|²)^{m/2} for some c > 0 and
        all sufficiently large |ξ|. This method uses a simpler heuristic: it evaluates the
        symbol on a grid of spatial and frequency coordinates and checks whether its minimum
        absolute value exceeds a specified threshold. A passing result is evidence, not
        proof, of ellipticity on the sampled region.

        Large grids are resampled to at most 32 points per axis to limit memory usage,
        particularly in 2D; zeros lying between resampled points can be missed.

        Parameters
        ----------
        x_grid : ndarray or tuple of ndarray
            Spatial grid: either a 1D array (x) or a tuple of two 1D arrays (x, y).
        xi_grid : ndarray or tuple of ndarray
            Frequency grid: either a 1D array (ξ) or a tuple of two 1D arrays (ξ, η).
        threshold : float, optional
            Minimum acceptable value for |p(x, ξ)|. If the smallest evaluated symbol value
            falls below this, the symbol is not considered elliptic.

        Returns
        -------
        bool
            True if the symbol is elliptic on the resampled grid, False otherwise.
        """
        RESAMPLE_SIZE = 32  # Cap per axis to prevent memory explosion (32**4 points in 2D)

        if self.dim == 1:
            x_vals = np.asarray(x_grid)
            xi_vals = np.asarray(xi_grid)
            # Resampling if necessary
            if len(x_vals) > RESAMPLE_SIZE:
                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
            if len(xi_vals) > RESAMPLE_SIZE:
                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)

            X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
            symbol_vals = self.p_func(X, XI)

        elif self.dim == 2:
            x_vals, y_vals = (np.asarray(v) for v in x_grid)
            xi_vals, eta_vals = (np.asarray(v) for v in xi_grid)

            # Spatial resampling
            if len(x_vals) > RESAMPLE_SIZE:
                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
            if len(y_vals) > RESAMPLE_SIZE:
                y_vals = np.linspace(y_vals.min(), y_vals.max(), RESAMPLE_SIZE)

            # Frequency resampling
            if len(xi_vals) > RESAMPLE_SIZE:
                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
            if len(eta_vals) > RESAMPLE_SIZE:
                eta_vals = np.linspace(eta_vals.min(), eta_vals.max(), RESAMPLE_SIZE)

            X, Y, XI, ETA = np.meshgrid(x_vals, y_vals, xi_vals, eta_vals, indexing='ij')
            symbol_vals = self.p_func(X, Y, XI, ETA)

        else:
            raise NotImplementedError("Only 1D and 2D operators are supported")

        min_abs_val = np.min(np.abs(symbol_vals))
        return min_abs_val > threshold

Check if the pseudo-differential symbol p(x, ξ) is elliptic over a given grid.

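A symbol is treated as elliptic here if its magnitude |p(x, ξ)| stays bounded away from zero at every sampled point of the spatial-frequency domain. This is a numerical heuristic. The classical definition for a symbol of order m asks for |p(x, ξ)| ≥ C|ξ|^m for some constants C, R > 0 whenever |ξ| ≥ R, and no finite grid can verify that. The method evaluates the symbol on a grid of spatial and frequency coordinates and checks whether its minimum absolute value exceeds the given threshold.

Grids longer than 32 points per axis are resampled to bound memory usage. This matters most in 2D, where the evaluation grid is four-dimensional. Resampling can skip points where p vanishes, so a True result is evidence of ellipticity, not proof.

Parameters

x_grid : ndarray or tuple of ndarray
    Spatial grid: a 1D array (x) in 1D, or a tuple of two 1D arrays (x, y) in 2D.
xi_grid : ndarray or tuple of ndarray
    Frequency grid: a 1D array (ξ) in 1D, or a tuple of two 1D arrays (ξ, η) in 2D.
threshold : float, optional
    Minimum acceptable value of |p(x, ξ)|. If the smallest evaluated value falls below this threshold, the symbol is not considered elliptic.

Returns

bool
    True if the symbol is elliptic on the resampled grid, False otherwise.

Raises

NotImplementedError
    If the operator dimension is neither 1 nor 2.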
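The 1D check reduces to a single number: the minimum of |p| over the sampled grid. The standalone sketch below reproduces that logic in plain NumPy. It also shows the resampling caveat in action: a 32-point frequency grid need not contain ξ = 0, so the non-elliptic transport symbol iξ can pass the test. The helper `min_abs_symbol_1d` and both sample symbols are hypothetical illustrations and are not part of psipy.

```python
import numpy as np

def min_abs_symbol_1d(p_func, x_grid, xi_grid, resample=32):
    """Return min |p(x, xi)| over a phase-space grid, capping each axis at `resample` points."""
    x_vals = np.asarray(x_grid, dtype=float)
    xi_vals = np.asarray(xi_grid, dtype=float)
    if x_vals.size > resample:
        x_vals = np.linspace(x_vals.min(), x_vals.max(), resample)
    if xi_vals.size > resample:
        xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), resample)
    X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
    return float(np.min(np.abs(p_func(X, XI))))

def elliptic_symbol(X, XI):
    return 1.0 + XI**2          # never vanishes: elliptic

def transport_symbol(X, XI):
    return 1j * XI              # vanishes on the line xi = 0: not elliptic

x = np.linspace(-np.pi, np.pi, 256)
xi = np.linspace(-50.0, 50.0, 257)   # odd count, so xi = 0 is an exact grid point

full_elliptic = min_abs_symbol_1d(elliptic_symbol, x, xi, resample=10_000)    # no resampling
full_transport = min_abs_symbol_1d(transport_symbol, x, xi, resample=10_000)  # no resampling
coarse_transport = min_abs_symbol_1d(transport_symbol, x, xi, resample=32)    # 32-point grids

print(full_elliptic)     # 1.0
print(full_transport)    # 0.0
print(coarse_transport)  # about 1.61: the 32-point frequency grid skips xi = 0
```

With any small threshold, the coarse result would report iξ as elliptic. Keeping an odd number of symmetric frequency samples, so that ξ = 0 stays on the grid, is one simple way to reduce this risk.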

def is_self_adjoint(self, tol=1e-10):
    def is_self_adjoint(self, tol=1e-10):
        """
        Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).

        A formally self-adjoint operator satisfies P = P*, where P* is the formal adjoint of P.
        On a suitable domain, self-adjointness guarantees a real spectrum and a unitary
        (norm-preserving) evolution e^{-itP}. This is why the property matters in quantum
        mechanics and for symmetric wave propagation.

        Parameters
        ----------
        tol : float, optional
            Tolerance for the numerical fallback. SymPy sometimes cannot decide whether p - p*
            vanishes. In that case the difference is sampled at random points and treated as
            zero if its largest magnitude is below tol.

        Returns
        -------
        bool
            True if the symbol p(x, ξ) equals its formal adjoint p*(x, ξ) within the given tolerance,
            indicating that the operator is formally self-adjoint.

        Notes
        -----
        - The formal adjoint is computed via conjugation and an asymptotic expansion valid as |ξ| → ∞.
        - Symbolic simplification is used first, so superficial differences between equivalent
          expressions do not affect the result.
        """
        p = self.symbol
        p_star = self.formal_adjoint()
        diff = simplify(p - p_star)
        is_zero = diff.equals(0)  # True, False, or None when SymPy cannot decide
        if is_zero is not None:
            return bool(is_zero)
        # Undecided symbolically: sample |p - p*| at random points (the box [-5, 5] is arbitrary)
        free = sorted(diff.free_symbols, key=str)
        f = lambdify(free, diff, 'numpy')
        rng = np.random.default_rng(0)
        samples = rng.uniform(-5.0, 5.0, size=(len(free), 256))
        values = np.abs(np.asarray(f(*samples), dtype=complex))
        return bool(np.all(values < tol))

Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).

A formally self-adjoint operator satisfies P = P*, where P* is the formal adjoint of P. On a suitable domain, self-adjointness guarantees a real spectrum and a unitary (norm-preserving) evolution e^{-itP}. This is why the property matters in quantum mechanics and for symmetric wave propagation.

Parameters

tol : float, optional
    Tolerance for the numerical fallback. SymPy sometimes cannot decide whether p - p* vanishes. In that case the difference is sampled at random points and treated as zero if its largest magnitude is below tol.

Returns

bool
    True if the symbol p(x, ξ) equals its formal adjoint p*(x, ξ) within the given tolerance, indicating that the operator is formally self-adjoint.

Notes

  • The formal adjoint is computed via conjugation and an asymptotic expansion valid as |ξ| → ∞.
  • Symbolic simplification is used first, so superficial differences between equivalent expressions do not affect the result.
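The sketch below illustrates the expansion behind the formal adjoint in 1D. It assumes the standard left (Kohn–Nirenberg) quantization, for which p*(x, ξ) ~ Σₐ (1/a!) ∂ξᵃ Dₓᵃ conj(p(x, ξ)) with Dₓ = -i∂ₓ. Whether psipy's `formal_adjoint()` uses exactly this convention is an assumption here. The helpers `formal_adjoint_1d` and `is_formally_self_adjoint` are hypothetical names written with plain SymPy.

```python
import sympy as sp

x, xi = sp.symbols('x xi', real=True)

def formal_adjoint_1d(p, order=4):
    """Truncated expansion p*(x, xi) ~ sum_a (1/a!) d_xi^a D_x^a conj(p), with D_x = -i d_x."""
    term = sp.conjugate(p)
    result = sp.Integer(0)
    for a in range(order + 1):
        result += (-sp.I) ** a / sp.factorial(a) * term
        term = sp.diff(term, x, xi)   # one more d_x d_xi for the next order
    return sp.expand(result)

def is_formally_self_adjoint(p, order=4):
    return sp.simplify(p - formal_adjoint_1d(p, order)) == 0

p1 = x * xi               # symbol of x*D_x, whose adjoint is D_x*x = x*D_x - i
p2 = x * xi - sp.I / 2    # symbol of the symmetrized operator (x*D_x + D_x*x)/2

print(is_formally_self_adjoint(p1))   # False
print(is_formally_self_adjoint(p2))   # True
```

For p1 the expansion gives p1* = xξ - i, which matches the operator identity Dₓx = xDₓ - i. For polynomial symbols the series terminates, so the truncation order only matters for genuinely pseudo-differential symbols.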
def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
    def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
        """
        Plot the cotangent fiber structure at a fixed spatial point (x₀[, y₀]).

        This visualization shows how the symbol p behaves on the cotangent fiber above a fixed
        spatial point. In microlocal analysis, this reveals the frequency content of the operator
        at that location, in particular the frequency directions in which p vanishes.

        Parameters
        ----------
        x_grid : ndarray
            Spatial grid values (1D); used only in the 1D case.
        xi_grid : ndarray
            Frequency grid values (1D). In 2D, the same grid is used for both ξ and η.
        x0 : float, optional
            x-coordinate of the base point; used only in the 2D case.
        y0 : float, optional
            y-coordinate of the base point; used only in the 2D case.

        Notes
        -----
        - In 1D: displays |p(x, ξ)| over the whole (x, ξ) grid. x0 is not used, so the plot
          shows the phase plane rather than a single fiber.
        - In 2D: fixes (x₀, y₀) and evaluates p(x₀, y₀, ξ, η), showing the fiber over that point.
        - The color map represents the magnitude of the symbol, highlighting regions where it
          vanishes or grows large.

        Raises
        ------
        NotImplementedError
            If the operator dimension is neither 1 nor 2.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.contourf(X, XI, np.abs(symbol_vals), levels=50, cmap='viridis')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x (position)')
            plt.ylabel('ξ (frequency)')
            plt.title('Cotangent Fiber Structure')
            plt.show()
        elif self.dim == 2:
            # Default 'xy' indexing: rows follow η and columns follow ξ, as contourf expects
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, xi_grid)
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            plt.contourf(xi_grid, xi_grid, np.abs(symbol_vals), levels=50, cmap='viridis')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Cotangent Fiber at x={x0}, y={y0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported.")

Plot the cotangent fiber structure at a fixed spatial point (x₀[, y₀]).

This visualization shows how the symbol p behaves on the cotangent fiber above a fixed spatial point. In microlocal analysis, this reveals the frequency content of the operator at that location, in particular the frequency directions in which p vanishes.

Parameters

x_grid : ndarray
    Spatial grid values (1D); used only in the 1D case.
xi_grid : ndarray
    Frequency grid values (1D). In 2D, the same grid is used for both ξ and η.
x0 : float, optional
    x-coordinate of the base point; used only in the 2D case.
y0 : float, optional
    y-coordinate of the base point; used only in the 2D case.

Notes

  • In 1D: displays |p(x, ξ)| over the whole (x, ξ) grid. x0 is not used, so the plot shows the phase plane rather than a single fiber.
  • In 2D: fixes (x₀, y₀) and evaluates p(x₀, y₀, ξ, η), showing the fiber over that point.
  • The color map represents the magnitude of the symbol, highlighting regions where it vanishes or grows large.

Raises

NotImplementedError
    If the operator dimension is neither 1 nor 2.
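In 2D, the fiber over (x₀, y₀) is the whole (ξ, η) plane. The dark regions of the plot mark where p nearly vanishes there. The sketch below evaluates a fiber the same way, using NumPy only. It assumes a wave-type symbol p = ξ² + η² - c(x, y)², where the wave speed c is a hypothetical field chosen for illustration. It then checks that the near-zero set of the fiber is the circle of radius c(x₀, y₀).

```python
import numpy as np

def c(x, y):
    """Hypothetical wave speed used only for this illustration."""
    return 1.0 + 0.5 * np.exp(-(x**2 + y**2))

def p_func(x, y, xi, eta):
    # Wave-type symbol: vanishes exactly where xi^2 + eta^2 = c(x, y)^2
    return xi**2 + eta**2 - c(x, y)**2

x0, y0 = 0.0, 0.0
freq = np.linspace(-3.0, 3.0, 401)
XI, ETA = np.meshgrid(freq, freq)          # default 'xy' indexing, as in the 2D branch above
fiber = p_func(x0, y0, XI, ETA)            # p restricted to the fiber over (x0, y0)

eps = 0.05
near_zero = np.abs(fiber) < eps            # approximate zero set of p on this fiber
radii = np.hypot(XI[near_zero], ETA[near_zero])
print(near_zero.sum(), radii.min(), radii.max())   # the radii cluster around c(0, 0) = 1.5
```

Over a point where c is larger, the circle of characteristic frequencies grows accordingly. This is the fiber-by-fiber picture of the characteristic set.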

def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
    def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
        """
        Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.

        This method visualizes the amplitude of the pseudo-differential operator's symbol for
        1D or 2D operators. In 2D, the frequency variables are frozen at (ξ₀, η₀), so the plot
        is a spatial slice of the symbol.

        Parameters
        ----------
        x_grid, y_grid : ndarray
            Spatial grids over which to evaluate the symbol. y_grid is required in 2D and ignored in 1D.
        xi_grid : ndarray
            Frequency grid (ξ); used only in 1D.
        eta_grid : ndarray, optional
            Not used; kept for API consistency with the other visualization methods.
        xi0, eta0 : float, optional
            Fixed frequency values for the 2D slice. Both default to zero.

        Notes
        -----
        - In 1D: visualizes |p(x, ξ)| over the (x, ξ) grid.
        - In 2D: visualizes |p(x, y, ξ₀, η₀)| at the fixed frequencies ξ₀ and η₀.
        - The color intensity represents the magnitude of the symbol, highlighting regions where the symbol is large or small.

        Raises
        ------
        ValueError
            If y_grid is not provided in 2D.
        NotImplementedError
            If the operator dimension is neither 1 nor 2.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.pcolormesh(X, XI, np.abs(symbol_vals), shading='auto')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Symbol Amplitude |p(x, ξ)|')
            plt.show()
        elif self.dim == 2:
            if y_grid is None:
                raise ValueError("y_grid must be provided for 2D visualization.")
            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
            XI = np.full_like(X, xi0)   # freeze the frequency variables on the spatial grid
            ETA = np.full_like(Y, eta0)
            symbol_vals = self.p_func(X, Y, XI, ETA)
            plt.pcolormesh(X, Y, np.abs(symbol_vals), shading='auto')
            plt.colorbar(label='|Symbol|')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.title(f'Symbol Amplitude at ξ={xi0}, η={eta0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported.")

Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.

This method visualizes the amplitude of the pseudo-differential operator's symbol for 1D or 2D operators. In 2D, the frequency variables are frozen at (ξ₀, η₀), so the plot is a spatial slice of the symbol.

Parameters

x_grid, y_grid : ndarray
    Spatial grids over which to evaluate the symbol. y_grid is required in 2D and ignored in 1D.
xi_grid : ndarray
    Frequency grid (ξ); used only in 1D.
eta_grid : ndarray, optional
    Not used; kept for API consistency with the other visualization methods.
xi0, eta0 : float, optional
    Fixed frequency values for the 2D slice. Both default to zero.

Notes

  • In 1D: visualizes |p(x, ξ)| over the (x, ξ) grid.
  • In 2D: visualizes |p(x, y, ξ₀, η₀)| at the fixed frequencies ξ₀ and η₀.
  • The color intensity represents the magnitude of the symbol, highlighting regions where the symbol is large or small.
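To see what a 2D slice shows, take a Schrödinger-type symbol p = ξ² + η² + V(x, y). At fixed (ξ₀, η₀), the slice is the potential landscape V offset by the kinetic term ξ₀² + η₀². The sketch below computes such a slice with NumPy. The harmonic potential V = x² + y² and the helper `p_func` are illustrative assumptions, not psipy objects.

```python
import numpy as np

def p_func(x, y, xi, eta):
    # Schrödinger-type symbol with a harmonic potential V(x, y) = x^2 + y^2 (illustrative choice)
    return xi**2 + eta**2 + (x**2 + y**2)

x_grid = np.linspace(-2.0, 2.0, 33)   # step 0.125, so x = 0 is an exact grid point
y_grid = np.linspace(-2.0, 2.0, 33)
xi0, eta0 = 1.0, 0.0

X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
XI = np.full_like(X, xi0)             # freeze the frequency variables on the spatial grid
ETA = np.full_like(Y, eta0)
amplitude = np.abs(p_func(X, Y, XI, ETA))

# The minimum sits at the bottom of the potential, at height xi0^2 + eta0^2
i, j = np.unravel_index(np.argmin(amplitude), amplitude.shape)
print(amplitude.min(), x_grid[i], y_grid[j])   # 1.0 0.0 0.0
```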
def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
    def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
        """
        Plot the phase (argument) of the pseudo-differential operator's symbol p(x, ξ) or p(x, y, ξ, η).

        This visualization highlights the oscillatory behavior of the symbol in phase space.
        Isolated zeros of a complex-valued p typically appear as points where all phase colors
        meet. The phase, as returned by np.angle, lies in (-π, π] and is drawn with the cyclic
        'twilight' colormap, so -π and π share the same color.

        Parameters
        ----------
        x_grid : ndarray
            1D array of spatial coordinates (x).
        xi_grid : ndarray
            1D array of frequency coordinates (ξ); used only in 1D.
        y_grid : ndarray, optional
            1D array of spatial coordinates (y); required in 2D. Default is None.
        eta_grid : ndarray, optional
            Not used; kept for API consistency with the other visualization methods.
        xi0 : float, optional
            Fixed value of ξ for the 2D slice. Default is 0.0.
        eta0 : float, optional
            Fixed value of η for the 2D slice. Default is 0.0.

        Notes
        -----
        - In 1D: displays arg p(x, ξ) over the (x, ξ) phase plane.
        - In 2D: displays arg p(x, y, ξ₀, η₀) for the fixed frequency values (ξ₀, η₀).
        - Uses plt.pcolormesh with the 'twilight' colormap to represent angles from -π to π.

        Raises
        ------
        ValueError
            If y_grid is not provided in 2D.
        NotImplementedError
            If the spatial dimension is not 1D or 2D.
        """
        if self.dim == 1:
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.pcolormesh(X, XI, np.angle(symbol_vals), shading='auto', cmap='twilight')
            plt.colorbar(label='arg(Symbol) [rad]')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Phase Portrait (arg p(x, ξ))')
            plt.show()
        elif self.dim == 2:
            if y_grid is None:
                raise ValueError("y_grid must be provided for 2D visualization.")
            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
            XI = np.full_like(X, xi0)   # freeze the frequency variables on the spatial grid
            ETA = np.full_like(Y, eta0)
            symbol_vals = self.p_func(X, Y, XI, ETA)
            plt.pcolormesh(X, Y, np.angle(symbol_vals), shading='auto', cmap='twilight')
            plt.colorbar(label='arg(Symbol) [rad]')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.title(f'Phase Portrait at ξ={xi0}, η={eta0}')
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported.")

Plot the phase (argument) of the pseudodifferential operator's symbol p(x, ξ) or p(x, y, ξ, η).

This visualization highlights the oscillatory behavior of the symbol in phase space. Isolated zeros of a complex-valued p typically appear as points where all phase colors meet. The phase, as returned by np.angle, lies in (-π, π] and is drawn with the cyclic 'twilight' colormap, so -π and π share the same color.

Parameters

x_grid : ndarray
    1D array of spatial coordinates (x).
xi_grid : ndarray
    1D array of frequency coordinates (ξ); used only in 1D.
y_grid : ndarray, optional
    1D array of spatial coordinates (y); required in 2D. Default is None.
eta_grid : ndarray, optional
    Not used; kept for API consistency with the other visualization methods.
xi0 : float, optional
    Fixed value of ξ for the 2D slice. Default is 0.0.
eta0 : float, optional
    Fixed value of η for the 2D slice. Default is 0.0.

Notes

  • In 1D: displays arg p(x, ξ) over the (x, ξ) phase plane.
  • In 2D: displays arg p(x, y, ξ₀, η₀) for the fixed frequency values (ξ₀, η₀).
  • Uses plt.pcolormesh with the 'twilight' colormap to represent angles from -π to π.

Raises

  • ValueError: if y_grid is not provided in 2D.
  • NotImplementedError: if the spatial dimension is not 1D or 2D.
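A phase portrait is easiest to read near a zero of p. Around an isolated zero, arg p runs through every color of the cyclic map, and the number of full turns (the winding number) is an integer. The sketch below uses NumPy only, with the symbol x + iξ as an illustrative choice that vanishes only at the origin. It computes the phase on a grid as the 1D branch does, then follows arg p around a circle enclosing the zero.

```python
import numpy as np

def p_func(x, xi):
    # Illustrative complex symbol x + i*xi; its only zero is the origin of phase space
    return x + 1j * xi

# Phase on a grid, computed as in the 1D branch: values lie in (-pi, pi]
x_grid = np.linspace(-1.0, 1.0, 64)
xi_grid = np.linspace(-1.0, 1.0, 64)
X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
phase = np.angle(p_func(X, XI))

# Follow arg p once around a circle of radius 0.5 enclosing the zero and count full turns
theta = np.linspace(0.0, 2.0 * np.pi, 400)
loop_phase = np.unwrap(np.angle(p_func(0.5 * np.cos(theta), 0.5 * np.sin(theta))))
winding = (loop_phase[-1] - loop_phase[0]) / (2.0 * np.pi)
print(round(winding))   # 1
```

np.unwrap removes the artificial 2π jumps introduced by wrapping to (-π, π]. Those jumps are the same seams that the cyclic 'twilight' colormap hides visually.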
def visualize_characteristic_set(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=[0.1]):
    def visualize_characteristic_set(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=[1e-1]):
        """
        Visualize the characteristic set of the pseudo-differential symbol, defined as the approximate zero set p(x, ξ) ≈ 0.

        In microlocal analysis, the characteristic set is the locus of points (x, ξ) in phase space,
        with ξ ≠ 0, where the principal symbol vanishes. Singularities of solutions of Pu = 0 are
        confined to this set, which makes it central to the study of propagation of singularities.

        Parameters
        ----------
        x_grid : ndarray
            Spatial grid values (1D array); used only in 1D.
        xi_grid : ndarray
            Frequency grid values (1D array) for ξ.
        y_grid : ndarray, optional
            Not used; kept for API consistency with the other visualization methods.
        eta_grid : ndarray, optional
            Frequency grid values (1D array) for η; required in 2D.
        y0 : float, optional
            Fixed y-coordinate of the base point in 2D. Default is 0.0.
        x0 : float, optional
            Fixed x-coordinate of the base point in 2D. Default is 0.0.
        levels : list of float, optional
            Contour levels ε at which |p| = ε is drawn. Default is [0.1].

        Notes
        -----
        - In 1D, this method plots the contours |p(x, ξ)| = ε, for each ε in levels, over the (x, ξ) plane.
        - In 2D, it evaluates the symbol at fixed (x₀, y₀) and plots the same contours in the (ξ, η) frequency plane.
        - For a real-valued symbol, each level ε yields curves on both sides of the zero set, and smaller ε
          brings them closer to p = 0. The plot helps identify the directions in which the operator fails
          to be elliptic.

        Raises
        ------
        ValueError
            If eta_grid is not provided in 2D.
        NotImplementedError
            If the operator dimension is neither 1 nor 2.

        Displays
        --------
        A matplotlib contour plot showing either:
            - the characteristic curve in the (x, ξ) phase plane (1D), or
            - the slice of the characteristic set in the (ξ, η) frequency plane at (x₀, y₀) (2D).
        """
        if self.dim == 1:
            x_grid = np.asarray(x_grid)
            xi_grid = np.asarray(xi_grid)
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            plt.contour(X, XI, np.abs(symbol_vals), levels=levels, colors='red')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Characteristic Set (p(x, ξ) ≈ 0)')
            plt.grid(True)
            plt.show()
        elif self.dim == 2:
            if eta_grid is None:
                raise ValueError("eta_grid must be provided for 2D visualization.")
            xi_grid = np.asarray(xi_grid)
            eta_grid = np.asarray(eta_grid)
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            # Pass the 2D coordinate arrays: with 'ij' indexing the rows follow ξ, so passing the
            # 1D grids would transpose the plot (or fail when the grids differ in length)
            plt.contour(xi_grid2, eta_grid2, np.abs(symbol_vals), levels=levels, colors='red')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Characteristic Set at x={x0}, y={y0}')
            plt.grid(True)
            plt.show()
        else:
            raise NotImplementedError("Only 1D/2D characteristic sets supported.")

Visualize the characteristic set of the pseudo-differential symbol, defined as the approximate zero set p(x, ξ) ≈ 0.

In microlocal analysis, the characteristic set is the locus of points (x, ξ) in phase space, with ξ ≠ 0, where the principal symbol vanishes. Singularities of solutions of Pu = 0 are confined to this set, which makes it central to the study of propagation of singularities.

Parameters

x_grid : ndarray
    Spatial grid values (1D array); used only in 1D.
xi_grid : ndarray
    Frequency grid values (1D array) for ξ.
y_grid : ndarray, optional
    Not used; kept for API consistency with the other visualization methods.
eta_grid : ndarray, optional
    Frequency grid values (1D array) for η; required in 2D.
y0 : float, optional
    Fixed y-coordinate of the base point in 2D. Default is 0.0.
x0 : float, optional
    Fixed x-coordinate of the base point in 2D. Default is 0.0.
levels : list of float, optional
    Contour levels ε at which |p| = ε is drawn. Default is [0.1].

Notes

  • In 1D, this method plots the contours |p(x, ξ)| = ε, for each ε in levels, over the (x, ξ) plane.
  • In 2D, it evaluates the symbol at fixed (x₀, y₀) and plots the same contours in the (ξ, η) frequency plane.
  • For a real-valued symbol, each level ε yields curves on both sides of the zero set, and smaller ε brings them closer to p = 0. The plot helps identify the directions in which the operator fails to be elliptic.

Raises

ValueError
    If eta_grid is not provided in 2D.
NotImplementedError
    If the operator dimension is neither 1 nor 2.

Displays

A matplotlib contour plot showing either:
  • the characteristic curve in the (x, ξ) phase plane (1D), or
  • the slice of the characteristic set in the (ξ, η) frequency plane at (x₀, y₀) (2D).
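For real-valued symbols, the |p| = ε contour can be replaced by a sharper estimate of the zero set itself. The sketch below is not what `visualize_characteristic_set` does. It is an alternative technique, assuming a real symbol: for each x, it locates sign changes of p along ξ and refines each one by linear interpolation. The wave speed c(x) is a hypothetical example, chosen so that the characteristic set ξ = ±c(x) is known exactly.

```python
import numpy as np

def c(x):
    """Hypothetical variable wave speed used only for this illustration."""
    return 1.0 + 0.5 * np.sin(x)

def p_func(X, XI):
    # Real symbol of a 1D wave-type operator: Char(P) = {xi = +c(x)} U {xi = -c(x)}
    return XI**2 - c(X)**2

x_grid = np.linspace(-3.0, 3.0, 50)
xi_grid = np.linspace(-3.0, 3.0, 600)
X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
P = p_func(X, XI)

# For each x (row), find consecutive xi samples where p changes sign and
# interpolate linearly between them to estimate the root
roots = []
for i, x in enumerate(x_grid):
    row = P[i]
    idx = np.nonzero(np.sign(row[:-1]) * np.sign(row[1:]) < 0)[0]
    for k in idx:
        xi_a, xi_b = xi_grid[k], xi_grid[k + 1]
        p_a, p_b = row[k], row[k + 1]
        roots.append((x, xi_a - p_a * (xi_b - xi_a) / (p_b - p_a)))

roots = np.array(roots)
err = np.abs(np.abs(roots[:, 1]) - c(roots[:, 0]))
print(len(roots), err.max())   # two roots per x; the error scales like the square of the xi step
```

The interpolation error is second order in the grid step, whereas the distance between a |p| = ε contour and the true zero set is fixed by ε. The trade-off is that sign changes only exist for real-valued symbols with simple zeros.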

def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
    def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
        """
        Visualize the norm of the gradient of the symbol in phase space.

        This method computes the magnitude of the gradient |∇p| of a pseudo-differential
        symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting colormap reveals
        regions where the symbol varies rapidly or remains nearly stationary,
        which is useful for analyzing the geometry of characteristic sets.

        Parameters
        ----------
        x_grid : numpy.ndarray
            1D array of spatial coordinates for the x-direction; used only in 1D.
        xi_grid : numpy.ndarray
            1D array of frequency coordinates (ξ).
        y_grid : numpy.ndarray, optional
            Not used; kept for API consistency with the other visualization methods. Default is None.
        eta_grid : numpy.ndarray, optional
            1D array of frequency coordinates (η); required in 2D. Default is None.
        x0 : float, optional
            Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0.
        y0 : float, optional
            Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.

        Returns
        -------
        None
            Displays a 2D colormap of |∇p| over the relevant phase-space domain.

        Raises
        ------
        ValueError
            If eta_grid is not provided in 2D.
        NotImplementedError
            If the operator dimension is neither 1 nor 2.

        Notes
        -----
        - In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
        - In 2D, the frequency gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀) over the (ξ, η) grid.
        - Derivatives are approximated with `np.gradient`, using the grid coordinates so that they are in physical units.
        - Where p = 0 and |∇p| ≠ 0, the characteristic set is locally a smooth curve. Points where both p and |∇p|
          are small indicate degenerate (possibly non-smooth) parts of the characteristic set.
        """
        if self.dim == 1:
            x_grid = np.asarray(x_grid)
            xi_grid = np.asarray(xi_grid)
            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
            symbol_vals = self.p_func(X, XI)
            # Pass coordinates so derivatives are in physical units (Δx and Δξ usually differ)
            grad_x = np.gradient(symbol_vals, x_grid, axis=0)
            grad_xi = np.gradient(symbol_vals, xi_grid, axis=1)
            # np.abs keeps the norm real when the symbol is complex-valued
            grad_norm = np.sqrt(np.abs(grad_x)**2 + np.abs(grad_xi)**2)
            plt.pcolormesh(X, XI, grad_norm, cmap='inferno', shading='auto')
            plt.colorbar(label='|∇p|')
            plt.xlabel('x')
            plt.ylabel('ξ')
            plt.title('Gradient Norm |∇p|')
            plt.grid(True)
            plt.show()
        elif self.dim == 2:
            if eta_grid is None:
                raise ValueError("eta_grid must be provided for 2D visualization.")
            xi_grid = np.asarray(xi_grid)
            eta_grid = np.asarray(eta_grid)
            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
            grad_xi = np.gradient(symbol_vals, xi_grid, axis=0)
            grad_eta = np.gradient(symbol_vals, eta_grid, axis=1)
            grad_norm = np.sqrt(np.abs(grad_xi)**2 + np.abs(grad_eta)**2)
            # 2D coordinate arrays keep the 'ij' layout correctly oriented
            plt.pcolormesh(xi_grid2, eta_grid2, grad_norm, cmap='inferno', shading='auto')
            plt.colorbar(label='|∇p|')
            plt.xlabel('ξ')
            plt.ylabel('η')
            plt.title(f'Gradient Norm at x={x0}, y={y0}')
            plt.grid(True)
            plt.show()
        else:
            raise NotImplementedError("Only 1D and 2D operators are supported.")

Visualize the norm of the gradient of the symbol in phase space.

This method computes the magnitude of the gradient |∇p| of a pseudo-differential symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting colormap reveals regions where the symbol varies rapidly or remains nearly stationary, which is useful for analyzing the geometry of characteristic sets.

Parameters

x_grid : numpy.ndarray
    1D array of spatial coordinates for the x-direction; used only in 1D.
xi_grid : numpy.ndarray
    1D array of frequency coordinates (ξ).
y_grid : numpy.ndarray, optional
    Not used; kept for API consistency with the other visualization methods. Default is None.
eta_grid : numpy.ndarray, optional
    1D array of frequency coordinates (η); required in 2D. Default is None.
x0 : float, optional
    Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0.
y0 : float, optional
    Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.

Returns

None
    Displays a 2D colormap of |∇p| over the relevant phase-space domain.

Notes

  • In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
  • In 2D, the frequency gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀) over the (ξ, η) grid.
  • Derivatives are approximated with np.gradient, using the grid coordinates so that they are in physical units.
  • Where p = 0 and |∇p| ≠ 0, the characteristic set is locally a smooth curve. Points where both p and |∇p| are small indicate degenerate (possibly non-smooth) parts of the characteristic set.
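np.gradient differentiates with respect to the array index unless coordinate arrays are passed. The two only agree up to the grid step, which makes the norm meaningless when Δx ≠ Δξ. The sketch below uses plain NumPy, with the harmonic-oscillator symbol x² + ξ² as an illustrative choice. It compares both variants against the exact |∇p| = 2√(x² + ξ²).

```python
import numpy as np

x_grid = np.linspace(-2.0, 2.0, 81)
xi_grid = np.linspace(-2.0, 2.0, 81)
X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
P = X**2 + XI**2                      # harmonic-oscillator symbol

# Physical units: pass the coordinate arrays for axis 0 (x) and axis 1 (xi)
dp_dx, dp_dxi = np.gradient(P, x_grid, xi_grid, edge_order=2)
grad_norm = np.sqrt(np.abs(dp_dx)**2 + np.abs(dp_dxi)**2)   # np.abs also covers complex p

# Index units: without coordinates, np.gradient assumes unit spacing
di_x, di_xi = np.gradient(P, edge_order=2)
index_norm = np.sqrt(np.abs(di_x)**2 + np.abs(di_xi)**2)

exact = 2.0 * np.sqrt(X**2 + XI**2)
h = x_grid[1] - x_grid[0]
print(np.allclose(grad_norm, exact), np.allclose(index_norm, h * exact))   # True True
```

Second-order differences (including edge_order=2 at the boundary) are exact for quadratics. This is why both comparisons hold to rounding error, and why the index-based norm is off by exactly the factor h.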
def plot_hamiltonian_flow(self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
    def plot_hamiltonian_flow(self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
        """
        Integrate and plot the Hamiltonian trajectories of the symbol in phase space.

        This method numerically integrates the Hamiltonian vector field derived from
        the operator's symbol to visualize how singularities propagate under the flow.
        It supports both 1D and 2D problems.

        Parameters
        ----------
        x0, xi0 : float
            Initial position and frequency (momentum) along x, used in both 1D and 2D.
        y0, eta0 : float, optional
            Initial position and frequency along y, used only in 2D; default to zero.
        tmax : float
            Final integration time for the ODE solver.
        n_steps : int
            Number of output times, equally spaced in [0, tmax], at which the trajectory is
            sampled (the solver chooses its own internal steps).
        show_field : bool, optional
            In 2D, overlay the spatial velocity field (dx/dt, dy/dt) evaluated with the momentum
            frozen at (ξ₀, η₀). Ignored in 1D. Default is True.

        Notes
        -----
        - The Hamiltonian vector field is obtained from the symplectic flow of the symbol.
        - If the field is complex-valued, only its real part is used for integration.
        - In 1D, the trajectory is plotted in the (x, ξ) phase plane.
        - In 2D, the spatial projection (x(t), y(t)) is shown, with the momentum vectors
          (ξ(t), η(t)) drawn along it as a quiver plot.

        Raises
        ------
        NotImplementedError
            If the spatial dimension is not 1D or 2D.

        Displays
        --------
        matplotlib plot
            Phase-space trajectory showing the evolution of position and momentum
            under the Hamiltonian dynamics.
        """
        from sympy import im, re, simplify

        def make_real(expr):
            # Evaluate pending operations, then keep only the real part of the expression
            expr = expr.doit(deep=True)
            return simplify(re(expr))

        H = self.symplectic_flow()

        if any(im(H[k]) != 0 for k in H):
            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")

        if self.dim == 1:
            x, = self.vars_x
            xi = symbols('xi', real=True)

            dxdt_expr = make_real(H['dx/dt'])
            dxidt_expr = make_real(H['dxi/dt'])

            dxdt = lambdify((x, xi), dxdt_expr, 'numpy')
            dxidt = lambdify((x, xi), dxidt_expr, 'numpy')

            def hamilton(t, Y):
                x, xi = Y
                return [dxdt(x, xi), dxidt(x, xi)]

            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0], t_eval=np.linspace(0, tmax, n_steps))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_steps:
                print(f"⚠️ Only {n_points} of {n_steps} points computed; plotting the partial trajectory.")

            x_vals, xi_vals = sol.y

            plt.plot(x_vals, xi_vals)
            plt.xlabel("x")
            plt.ylabel("ξ")
            plt.title("Hamiltonian Flow in Phase Space (1D)")
            plt.grid(True)
            plt.show()

        elif self.dim == 2:
            x, y = self.vars_x
            xi, eta = symbols('xi eta', real=True)

            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')

            def hamilton(t, Y):
                x, y, xi, eta = Y
                return [
                    dxdt(x, y, xi, eta),
                    dydt(x, y, xi, eta),
                    dxidt(x, y, xi, eta),
                    detadt(x, y, xi, eta)
                ]

            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0], t_eval=np.linspace(0, tmax, n_steps))

            if sol.status != 0:
                print(f"⚠️ Integration warning: {sol.message}")

            n_points = sol.y.shape[1]
            if n_points < n_steps:
                print(f"⚠️ Only {n_points} of {n_steps} points computed; plotting the partial trajectory.")

            x_vals, y_vals, xi_vals, eta_vals = sol.y

            plt.plot(x_vals, y_vals, label='Position')
            plt.quiver(x_vals, y_vals, xi_vals, eta_vals, scale=20, width=0.003, alpha=0.5, color='r', label='Momentum')

            # Optional background: spatial velocity field with the momentum frozen at (xi0, eta0)
            if show_field:
                X, Y = np.meshgrid(np.linspace(min(x_vals), max(x_vals), 20),
                                   np.linspace(min(y_vals), max(y_vals), 20))
                XI, ETA = xi0 * np.ones_like(X), eta0 * np.ones_like(Y)
                # lambdify returns a scalar for constant components; quiver needs full arrays
                U = np.broadcast_to(dxdt(X, Y, XI, ETA), X.shape)
                V = np.broadcast_to(dydt(X, Y, XI, ETA), X.shape)
                plt.quiver(X, Y, U, V, color='gray', alpha=0.2, scale=30, width=0.002)

            plt.xlabel("x")
            plt.ylabel("y")
            plt.title("Hamiltonian Flow: Spatial Projection (2D)")
            plt.legend()
            plt.grid(True)
            plt.axis('equal')
            plt.show()

        else:
            raise NotImplementedError("Only 1D and 2D Hamiltonian flows are supported.")

Integrate and plot the Hamiltonian trajectories of the symbol in phase space.

This method numerically integrates the Hamiltonian vector field derived from the operator's symbol to visualize how singularities propagate under the flow. It supports both 1D and 2D problems.

Parameters

x0, xi0 : float
    Initial position and frequency (momentum) along x, used in both 1D and 2D.
y0, eta0 : float, optional
    Initial position and frequency along y, used only in 2D; default to zero.
tmax : float
    Final integration time for the ODE solver.
n_steps : int
    Number of output times, equally spaced in [0, tmax], at which the trajectory is sampled (the solver chooses its own internal steps).
show_field : bool, optional
    In 2D, overlay the spatial velocity field (dx/dt, dy/dt) evaluated with the momentum frozen at (ξ₀, η₀). Ignored in 1D. Default is True.

Notes

  • The Hamiltonian vector field is obtained from the symplectic flow of the symbol.
  • If the field is complex-valued, only its real part is used for integration.
  • In 1D, the trajectory is plotted in the (x, ξ) phase plane.
  • In 2D, the spatial projection (x(t), y(t)) is shown, with the momentum vectors (ξ(t), η(t)) drawn along it as a quiver plot.

Raises

NotImplementedError
    If the spatial dimension is not 1D or 2D.

Displays

matplotlib plot
    Phase-space trajectory showing the evolution of position and momentum under the Hamiltonian dynamics.
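The integration step amounts to solving Hamilton's equations dx/dt = ∂p/∂ξ, dξ/dt = -∂p/∂x. The self-contained sketch below does this for the harmonic-oscillator symbol p = (x² + ξ²)/2, an illustrative choice. It derives the vector field directly with SymPy rather than through psipy's `symplectic_flow()`. Two facts make convenient sanity checks: p is conserved along its own flow, and every orbit of this symbol has period 2π.

```python
import numpy as np
import sympy as sp
from scipy.integrate import solve_ivp

x, xi = sp.symbols('x xi', real=True)
p = (x**2 + xi**2) / 2                      # harmonic-oscillator symbol (illustrative)

# Hamilton's equations: dx/dt = dp/dxi, dxi/dt = -dp/dx
dxdt = sp.lambdify((x, xi), sp.diff(p, xi), 'numpy')
dxidt = sp.lambdify((x, xi), -sp.diff(p, x), 'numpy')

def hamilton(t, Y):
    return [dxdt(Y[0], Y[1]), dxidt(Y[0], Y[1])]

tmax = 2.0 * np.pi                          # one full period
t_eval = np.linspace(0.0, tmax, 200)
sol = solve_ivp(hamilton, [0.0, tmax], [1.0, 0.0], t_eval=t_eval, rtol=1e-9, atol=1e-12)

x_vals, xi_vals = sol.y
energy = 0.5 * (x_vals**2 + xi_vals**2)     # value of p along the trajectory
drift = energy.max() - energy.min()
print(sol.status, drift < 1e-6, np.allclose(sol.y[:, -1], [1.0, 0.0], atol=1e-6))   # 0 True True
```

The method above calls solve_ivp with its default tolerances. Checking the drift of p along the computed trajectory is a cheap way to decide whether tighter rtol/atol values are needed for a given symbol.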

def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
    def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
        """
        Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.

        The plotted vector field corresponds to (∂_ξ p, -∂_x p), where p(x, ξ) is the principal symbol
        of the pseudo-differential operator. This field governs the bicharacteristic flow in phase space.

        Parameters
        ----------
        xlim : tuple of float
            Range for spatial variable x, as (x_min, x_max).
        klim : tuple of float
            Range for frequency variable ξ, as (ξ_min, ξ_max).
        density : int
            Number of grid points per axis for the visualization grid.

        Raises
        ------
        NotImplementedError
            If the operator is not one-dimensional.

        Notes
        -----
        - Only supports one-dimensional operators.
        - Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
        - Numerical evaluation is done via lambdify with the NumPy backend.
        - Visualization uses a matplotlib quiver plot to show vector directions.
        """
        if self.dim != 1:
            raise NotImplementedError("Only 1D version implemented.")

        x_vals = np.linspace(*xlim, density)
        xi_vals = np.linspace(*klim, density)
        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')

        x, = self.vars_x
        xi = symbols('xi', real=True)
        H = self.symplectic_flow()
        dxdt = lambdify((x, xi), simplify(H['dx/dt']), 'numpy')
        dxidt = lambdify((x, xi), simplify(H['dxi/dt']), 'numpy')

        # lambdify returns a scalar for constant components (e.g. dξ/dt = 0); quiver needs full arrays
        U = np.broadcast_to(dxdt(X, XI), X.shape)
        V = np.broadcast_to(dxidt(X, XI), X.shape)

        plt.quiver(X, XI, U, V, scale=10, width=0.005)
        plt.xlabel('x')
        plt.ylabel(r'$\xi$')
        plt.title("Symplectic Vector Field (1D)")
        plt.grid(True)
        plt.show()

Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.

The plotted vector field corresponds to (∂_ξ p, -∂_x p), where p(x, ξ) is the principal symbol of the pseudo-differential operator. This field governs the bicharacteristic flow in phase space.

Parameters

xlim : tuple of float
    Range for spatial variable x, as (x_min, x_max).
klim : tuple of float
    Range for frequency variable ξ, as (ξ_min, ξ_max).
density : int
    Number of grid points per axis for the visualization grid.

Raises

NotImplementedError
    If called on a 2D operator (currently only 1D implementation available).

Notes

  • Only supports one-dimensional operators.
  • Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
  • Numerical evaluation is done via lambdify with NumPy backend.
  • Visualization uses matplotlib quiver plot to show vector directions.
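The field (∂_ξ p, -∂_x p) can be reproduced with SymPy and NumPy alone. This sketch assumes the example symbol p = ξ² and builds both components directly with `sympy.diff` instead of `symplectic_flow()`. It also shows a pitfall. When a component does not depend on x or ξ, `lambdify` returns a bare scalar, and `np.broadcast_to` restores the grid shape that array code generally expects.

```python
import numpy as np
import sympy as sp

x, xi = sp.symbols('x xi', real=True)

# Example symbol (assumption): p = xi**2 has no x-dependence,
# so the second component -dp/dx is the constant 0.
p = xi**2

# Symplectic (Hamiltonian) vector field (dp/dxi, -dp/dx)
dxdt = sp.lambdify((x, xi), sp.diff(p, xi), 'numpy')
dxidt = sp.lambdify((x, xi), -sp.diff(p, x), 'numpy')

X, XI = np.meshgrid(np.linspace(-2, 2, 30), np.linspace(-5, 5, 30), indexing='ij')

# lambdify returns a plain scalar for the constant component;
# broadcasting gives both components the shape of the grid.
U = np.broadcast_to(dxdt(X, XI), X.shape)
V = np.broadcast_to(dxidt(X, XI), X.shape)

print(np.ndim(dxidt(X, XI)), U.shape, V.shape)
```

`plt.quiver(X, XI, U, V)` would then draw horizontal arrows whose length grows with |ξ|. That picture is the free-particle flow, in which momentum is conserved.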
def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=0.001, density=300):
def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=1e-3, density=300):
    """
    Visualize the micro-support of the operator by plotting the reciprocal of the symbol magnitude, 1 / |p(x, ξ)|.

    Regions of phase space (x, ξ) where |p(x, ξ)| is small produce large values of 1/|p(x, ξ)|,
    so the plot highlights the neighborhood of the characteristic set {p = 0}, where the
    operator fails to be elliptic.

    Parameters
    ----------
    xlim : tuple
        Spatial domain limits (x_min, x_max).
    klim : tuple
        Frequency domain limits (ξ_min, ξ_max).
    threshold : float
        Intended threshold below which |p(x, ξ)| is considered effectively zero.
        Currently unused: the implementation regularizes with the fixed constant 1e-10 instead.
    density : int
        Number of grid points along each axis for visualization resolution.

    Raises
    ------
    NotImplementedError
        If called on an operator whose dimension is greater than 1 (only 1D visualization is supported).

    Notes
    -----
    - This method evaluates the symbol p(x, ξ) over a grid and plots its reciprocal to emphasize
      regions where the symbol is near zero.
    - A small constant (1e-10) is added to the denominator to avoid division by zero.
    - The resulting plot helps identify characteristic sets.
    """
    if self.dim != 1:
        raise NotImplementedError("Only 1D micro-support visualization implemented.")

    x_vals = np.linspace(*xlim, density)
    xi_vals = np.linspace(*klim, density)
    X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
    Z = np.abs(self.p_func(X, XI))

    plt.contourf(X, XI, 1 / (Z + 1e-10), levels=100, cmap='inferno')
    plt.colorbar(label=r'$1/|p(x,\xi)|$')
    plt.xlabel('x')
    plt.ylabel(r'$\xi$')
    plt.title("Micro-Support Estimate (1/|Symbol|)")
    plt.show()

Visualize the micro-support of the operator by plotting the reciprocal of the symbol magnitude, 1 / |p(x, ξ)|.

Regions of phase space (x, ξ) where |p(x, ξ)| is small produce large values of 1/|p(x, ξ)|, so the plot highlights the neighborhood of the characteristic set {p = 0}, where the operator fails to be elliptic.

Parameters

xlim : tuple
    Spatial domain limits (x_min, x_max).
klim : tuple
    Frequency domain limits (ξ_min, ξ_max).
threshold : float
    Intended threshold below which |p(x, ξ)| is considered effectively zero. Currently unused: the implementation regularizes with the fixed constant 1e-10 instead.
density : int
    Number of grid points along each axis for visualization resolution.

Raises

NotImplementedError
    If called on an operator whose dimension is greater than 1 (only 1D visualization is supported).

Notes

  • This method evaluates the symbol p(x, ξ) over a grid and plots its reciprocal to emphasize regions where the symbol is near zero.
  • A small constant (1e-10) is added to the denominator to avoid division by zero.
  • The resulting plot helps identify characteristic sets.
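Below is a NumPy-only sketch of the same estimate. It assumes the hypothetical transport-type symbol p(x, ξ) = ξ - 2x, whose characteristic set is the line ξ = 2x. For each x, the sketch finds the frequency where the regularized 1/|p| peaks, which recovers that line on the grid.

```python
import numpy as np

def p(x, xi):
    # Hypothetical symbol whose characteristic set {p = 0} is the line xi = 2x
    return xi - 2.0 * x

x_vals = np.linspace(-2, 2, 201)
xi_vals = np.linspace(-10, 10, 2001)
X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')

# Regularized reciprocal, using the same fixed constant as visualize_micro_support
inv_p = 1.0 / (np.abs(p(X, XI)) + 1e-10)

# For each x (axis 0), the frequency where 1/|p| peaks approximates the characteristic set
xi_peak = xi_vals[np.argmax(inv_p, axis=1)]
```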
def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
    """
    Plot the group velocity field ∇_ξ p(x, ξ) for 1D pseudo-differential operators.

    The group velocity is the speed at which a wave packet concentrated near frequency ξ
    propagates in a dispersive medium. In 1D it is the derivative ∂p/∂ξ of the symbol p(x, ξ)
    with respect to the frequency variable ξ.

    Parameters
    ----------
    xlim : tuple of float
        Spatial domain limits (x-axis).
    klim : tuple of float
        Frequency domain limits (ξ-axis).
    density : int
        Number of grid points per axis used for visualization.

    Raises
    ------
    NotImplementedError
        If called on a 2D operator, since this visualization is only implemented for 1D.

    Notes
    -----
    - Each arrow at (x, ξ) is drawn as (1, ∂p/∂ξ), so its slope encodes the group velocity.
    - Used for analyzing wave propagation properties and dispersion relations.
    - Requires the symbolic expression self.symbol, which depends on x and ξ.
    """
    if self.dim != 1:
        raise NotImplementedError("Only 1D group velocity visualization implemented.")

    x, = self.vars_x
    xi = symbols('xi', real=True)
    dp_dxi = diff(self.symbol, xi)
    grad_func = lambdify((x, xi), dp_dxi, 'numpy')

    x_vals = np.linspace(*xlim, density)
    xi_vals = np.linspace(*klim, density)
    X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
    V = grad_func(X, XI)

    plt.quiver(X, XI, np.ones_like(V), V, scale=10, width=0.004)
    plt.xlabel('x')
    plt.ylabel(r'$\xi$')
    plt.title("Group Velocity Field (1D)")
    plt.grid(True)
    plt.show()

Plot the group velocity field ∇_ξ p(x, ξ) for 1D pseudo-differential operators.

The group velocity is the speed at which a wave packet concentrated near frequency ξ propagates in a dispersive medium. In 1D it is the derivative ∂p/∂ξ of the symbol p(x, ξ) with respect to the frequency variable ξ.

Parameters

xlim : tuple of float
    Spatial domain limits (x-axis).
klim : tuple of float
    Frequency domain limits (ξ-axis).
density : int
    Number of grid points per axis used for visualization.

Raises

NotImplementedError
    If called on a 2D operator, since this visualization is only implemented for 1D.

Notes

  • Each arrow at (x, ξ) is drawn as (1, ∂p/∂ξ), so its slope encodes the group velocity.
  • Used for analyzing wave propagation properties and dispersion relations.
  • Requires the symbolic expression self.symbol, which depends on x and ξ.
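For a concrete dispersion relation, take the Klein-Gordon-type symbol p(ξ) = √(1 + ξ²), which is an assumed example. Its group velocity ξ/√(1 + ξ²) stays strictly below 1 in magnitude. Its phase velocity p/ξ exceeds 1, and the product of the two is exactly 1. The sketch computes ∂p/∂ξ symbolically, as `group_velocity_field` does, and evaluates it on a frequency grid.

```python
import numpy as np
import sympy as sp

x, xi = sp.symbols('x xi', real=True)

p = sp.sqrt(1 + xi**2)      # assumed dispersion relation (no x-dependence)
vg = sp.diff(p, xi)         # group velocity dp/dxi
vp = p / xi                 # phase velocity, for comparison

vg_func = sp.lambdify((x, xi), vg, 'numpy')

xi_samples = np.linspace(-10, 10, 101)
# x is unused by this symbol, so any value works; broadcast in case of a scalar result
v = np.broadcast_to(vg_func(0.0, xi_samples), xi_samples.shape)
```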
def animate_singularity(self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0, tmax=4.0, n_frames=100, projection=None):
def animate_singularity(self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0,
                        tmax=4.0, n_frames=100, projection=None):
    """
    Animate the propagation of a singularity under the Hamiltonian flow.

    This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space
    according to the Hamiltonian dynamics induced by the principal symbol of the operator.
    The animation integrates the Hamiltonian equations of motion and supports various projections:
    position (x-y), frequency (ξ-η), or mixed phase space coordinates.

    Parameters
    ----------
    xi0, eta0 : float
        Initial frequency components (ξ₀, η₀).
    x0, y0 : float
        Initial spatial coordinates (x₀, y₀).
    tmax : float
        Total time of integration (final animation time).
    n_frames : int
        Number of frames in the resulting animation.
    projection : str or None
        Type of projection to display:
            - 'position' : x vs y (or x alone in 1D)
            - 'frequency': ξ vs η (or ξ alone in 1D)
            - 'phase'    : x vs ξ in 1D, x vs η in 2D
            If None, defaults to 'phase' in 1D and 'position' in 2D.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.

    Raises
    ------
    ValueError
        If projection is not 'position', 'frequency', 'phase' or None.

    Notes
    -----
    - In 1D, only one spatial and one frequency variable are used.
    - Complex-valued Hamiltonian fields are truncated to their real parts for integration.
    - Trajectories are shown with both instantaneous position (dot) and full path (dashed line).
    """
    rc('animation', html='jshtml')

    def make_real(expr):
        from sympy import re, simplify
        expr = expr.doit(deep=True)
        return simplify(re(expr))

    H = self.symplectic_flow()

    H = {k: v.doit(deep=True) for k, v in H.items()}

    if any(im(H[k]) != 0 for k in H):
        print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")

    if self.dim == 1:
        x, = self.vars_x
        xi = symbols('xi', real=True)

        dxdt = lambdify((x, xi), make_real(H['dx/dt']), 'numpy')
        dxidt = lambdify((x, xi), make_real(H['dxi/dt']), 'numpy')

        def hamilton(t, Y):
            x, xi = Y
            return [dxdt(x, xi), dxidt(x, xi)]

        sol = solve_ivp(hamilton, [0, tmax], [x0, xi0],
                        t_eval=np.linspace(0, tmax, n_frames))

        if sol.status != 0:
            print(f"⚠️ Integration warning: {sol.message}")

        n_points = sol.y.shape[1]
        if n_points < n_frames:
            print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
            n_frames = n_points

        x_vals, xi_vals = sol.y

        if projection is None:
            projection = 'phase'

        fig, ax = plt.subplots()
        point, = ax.plot([], [], 'ro')
        traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)

        if projection == 'phase':
            ax.set_xlabel('x')
            ax.set_ylabel(r'$\xi$')
            ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
            ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)

            def update(i):
                point.set_data([x_vals[i]], [xi_vals[i]])
                traj.set_data(x_vals[:i+1], xi_vals[:i+1])
                return point, traj

        elif projection == 'position':
            # 1D has no second spatial coordinate: plot x(t) against time
            t_vals = sol.t
            ax.set_xlabel('t')
            ax.set_ylabel('x')
            ax.set_xlim(np.min(t_vals), np.max(t_vals))
            ax.set_ylim(np.min(x_vals) - 1, np.max(x_vals) + 1)

            def update(i):
                point.set_data([t_vals[i]], [x_vals[i]])
                traj.set_data(t_vals[:i+1], x_vals[:i+1])
                return point, traj

        elif projection == 'frequency':
            # 1D has no second frequency coordinate: plot ξ(t) against time
            t_vals = sol.t
            ax.set_xlabel('t')
            ax.set_ylabel(r'$\xi$')
            ax.set_xlim(np.min(t_vals), np.max(t_vals))
            ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)

            def update(i):
                point.set_data([t_vals[i]], [xi_vals[i]])
                traj.set_data(t_vals[:i+1], xi_vals[:i+1])
                return point, traj

        else:
            raise ValueError("Invalid projection mode")

        ax.set_title(f"1D Singularity Flow ({projection})")
        ax.grid(True)
        ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
        plt.close(fig)
        return ani

    elif self.dim == 2:
        x, y = self.vars_x
        xi, eta = symbols('xi eta', real=True)

        dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
        dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
        dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
        detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')

        def hamilton(t, Y):
            x, y, xi, eta = Y
            return [
                dxdt(x, y, xi, eta),
                dydt(x, y, xi, eta),
                dxidt(x, y, xi, eta),
                detadt(x, y, xi, eta)
            ]

        sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0],
                        t_eval=np.linspace(0, tmax, n_frames))

        if sol.status != 0:
            print(f"⚠️ Integration warning: {sol.message}")

        n_points = sol.y.shape[1]
        if n_points < n_frames:
            print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
            n_frames = n_points

        x_vals, y_vals, xi_vals, eta_vals = sol.y

        if projection is None:
            projection = 'position'

        fig, ax = plt.subplots()
        point, = ax.plot([], [], 'ro')
        traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)

        if projection == 'position':
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
            ax.set_ylim(np.min(y_vals) - 1, np.max(y_vals) + 1)

            def update(i):
                point.set_data([x_vals[i]], [y_vals[i]])
                traj.set_data(x_vals[:i+1], y_vals[:i+1])
                return point, traj

        elif projection == 'frequency':
            ax.set_xlabel(r'$\xi$')
            ax.set_ylabel(r'$\eta$')
            ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
            ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)

            def update(i):
                point.set_data([xi_vals[i]], [eta_vals[i]])
                traj.set_data(xi_vals[:i+1], eta_vals[:i+1])
                return point, traj

        elif projection == 'phase':
            ax.set_xlabel('x')
            ax.set_ylabel(r'$\eta$')
            ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
            ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)

            def update(i):
                point.set_data([x_vals[i]], [eta_vals[i]])
                traj.set_data(x_vals[:i+1], eta_vals[:i+1])
                return point, traj

        else:
            raise ValueError("Invalid projection mode")

        ax.set_title(f"2D Singularity Flow ({projection})")
        ax.grid(True)
        ax.axis('equal')
        ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
        plt.close(fig)
        return ani

Animate the propagation of a singularity under the Hamiltonian flow.

This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space according to the Hamiltonian dynamics induced by the principal symbol of the operator. The animation integrates the Hamiltonian equations of motion and supports various projections: position (x-y), frequency (ξ-η), or mixed phase space coordinates.

Parameters

xi0, eta0 : float
    Initial frequency components (ξ₀, η₀).
x0, y0 : float
    Initial spatial coordinates (x₀, y₀).
tmax : float
    Total time of integration (final animation time).
n_frames : int
    Number of frames in the resulting animation.
projection : str or None
    Type of projection to display:
        - 'position' : x vs y (or x alone in 1D)
        - 'frequency': ξ vs η (or ξ alone in 1D)
        - 'phase'    : x vs ξ in 1D, x vs η in 2D
    If None, defaults to 'phase' in 1D and 'position' in 2D.

Returns

matplotlib.animation.FuncAnimation
    Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.

Raises

ValueError
    If projection is not 'position', 'frequency', 'phase' or None.

Notes

  • In 1D, only one spatial and one frequency variable are used.
  • Complex-valued Hamiltonian fields are truncated to their real parts for integration.
  • Trajectories are shown with both instantaneous position (dot) and full path (dashed line).
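The integration behind the animation can be checked on a case with a known answer. Take the 2D wave-equation principal symbol p = √(ξ² + η²) as an assumed example. Hamilton's equations then give dx/dt = ξ/|k| and dy/dt = η/|k|, with constant frequencies, so a singularity moves in a straight line at unit speed. The sketch writes these equations by hand rather than through `symplectic_flow()`. It also mirrors the method's safeguard for a solver that returns fewer than `n_frames` points.

```python
import numpy as np
from scipy.integrate import solve_ivp

def hamilton(t, Y):
    # Assumed symbol p = sqrt(xi**2 + eta**2), with |k| = sqrt(xi**2 + eta**2):
    # dx/dt = xi/|k|, dy/dt = eta/|k|, dxi/dt = deta/dt = 0 (no x-dependence)
    x, y, xi, eta = Y
    k = np.hypot(xi, eta)
    return [xi / k, eta / k, 0.0, 0.0]

# Same values as the defaults of animate_singularity
x0, y0, xi0, eta0 = 0.0, 0.0, 5.0, 0.0
tmax, n_frames = 4.0, 100

sol = solve_ivp(hamilton, [0.0, tmax], [x0, y0, xi0, eta0],
                t_eval=np.linspace(0.0, tmax, n_frames))

# Animate only the frames that were actually computed
n_frames = min(n_frames, sol.y.shape[1])
x_vals, y_vals, xi_vals, eta_vals = sol.y
```

With ξ₀ = 5 and η₀ = 0, the singularity travels along the x-axis with x(t) = t. The 'position' projection would therefore show a horizontal segment from 0 to tmax.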
def interactive_symbol_analysis(pseudo_op, xlim=(-2, 2), ylim=(-2, 2), xi_range=(0.1, 5), eta_range=(-5, 5), density=100):
def interactive_symbol_analysis(pseudo_op,
                                xlim=(-2, 2), ylim=(-2, 2),
                                xi_range=(0.1, 5), eta_range=(-5, 5),
                                density=100):
    """
    Launch an interactive dashboard for symbol exploration using ipywidgets.

    This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol.
    It supports multiple visualization modes in 1D and 2D, including group velocity fields (1D only), micro-support estimates,
    symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.

    Parameters
    ----------
    pseudo_op : PseudoDifferentialOperator
        The pseudo-differential operator whose symbol is to be analyzed interactively.
    xlim, ylim : tuple of float
        Spatial domain limits along the x and y axes, respectively.
    xi_range, eta_range : tuple
        Frequency domain limits along the ξ and η axes, respectively.
    density : int
        Number of points per axis used to construct the evaluation grid. Controls resolution.

    Notes
    -----
    - In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
    - In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
    - Only the sliders relevant to the selected mode are shown; the plot updates whenever a slider or the dropdown menu changes.
    - Supported visualization modes:
        'Symbol Amplitude'           : |p(x,ξ)| or |p(x,y,ξ,η)|
        'Symbol Phase'               : arg(p(x,ξ)) or arg(p(x,y,ξ,η))
        'Micro-Support (1/|p|)'      : Reciprocal of symbol magnitude
        'Cotangent Fiber'            : Structure of symbol over frequency space at fixed x
        'Characteristic Set'         : Zero set approximation {p ≈ 0}
        'Characteristic Gradient'    : |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)|
        'Group Velocity Field'       : ∇_ξ p(x,ξ) (1D only; not offered in 2D)
        'Symplectic Vector Field'    : (∇_ξ p, -∇_x p) in 1D; spatial components (∂_ξ p, ∂_η p) at fixed (ξ₀, η₀) in 2D
        'Hamiltonian Flow'           : Trajectories generated by the Hamiltonian vector field

    Raises
    ------
    NotImplementedError
        If the spatial dimension is not 1D or 2D.

    Displays
    --------
    Interactive matplotlib figures with dynamic updates based on widget inputs.
    """
    dim = pseudo_op.dim
    expr = pseudo_op.expr
    vars_x = pseudo_op.vars_x

    mode_selector_1D = Dropdown(
        options=[
            'Symbol Amplitude',
            'Symbol Phase',
            'Micro-Support (1/|p|)',
            'Cotangent Fiber',
            'Characteristic Set',
            'Characteristic Gradient',
            'Group Velocity Field',
            'Symplectic Vector Field',
            'Hamiltonian Flow',
        ],
        value='Symbol Amplitude',
        description='Mode:'
    )

    mode_selector_2D = Dropdown(
        options=[
            'Symbol Amplitude',
            'Symbol Phase',
            'Micro-Support (1/|p|)',
            'Cotangent Fiber',
            'Characteristic Set',
            'Characteristic Gradient',
            'Symplectic Vector Field',
            'Hamiltonian Flow',
        ],
        value='Symbol Amplitude',
        description='Mode:'
    )

    x_vals = np.linspace(*xlim, density)
    if dim == 2:
        y_vals = np.linspace(*ylim, density)

    if dim == 1:
        x, = vars_x
        xi = symbols('xi', real=True)
        grad_func = lambdify((x, xi), diff(expr, xi), 'numpy')
        symplectic_func = lambdify((x, xi), [diff(expr, xi), -diff(expr, x)], 'numpy')
        symbol_func = lambdify((x, xi), expr, 'numpy')

        xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
        x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')

        def plot_1d(mode, xi0, x0):
            X = x_vals[:, None]

            def on_grid(values):
                # lambdify returns a plain scalar when the expression does not depend on x;
                # broadcast it to one value per grid point so plotting calls get matching shapes
                return np.broadcast_to(values, X.shape).ravel()

            if mode == 'Group Velocity Field':
                V = on_grid(grad_func(X, xi0))
                plt.quiver(x_vals, V, np.ones_like(V), V, scale=10, width=0.004)
                plt.xlabel('x')
                plt.title(f'Group Velocity Field at ξ={xi0:.2f}')

            elif mode == 'Micro-Support (1/|p|)':
                Z = 1 / (np.abs(on_grid(symbol_func(X, xi0))) + 1e-10)
                plt.plot(x_vals, Z)
                plt.xlabel('x')
                plt.title(f'Micro-Support (1/|p|) at ξ={xi0:.2f}')

            elif mode == 'Symplectic Vector Field':
                U, V = symplectic_func(X, xi0)
                U, V = on_grid(U), on_grid(V)
                plt.quiver(x_vals, V, U, V, scale=10, width=0.004)
                plt.xlabel('x')
                plt.title(f'Symplectic Field at ξ={xi0:.2f}')

            elif mode == 'Symbol Amplitude':
                Z = np.abs(on_grid(symbol_func(X, xi0)))
                plt.plot(x_vals, Z)
                plt.xlabel('x')
                plt.title(f'Symbol Amplitude |p(x,ξ)| at ξ={xi0:.2f}')

            elif mode == 'Symbol Phase':
                Z = np.angle(on_grid(symbol_func(X, xi0)))
                plt.plot(x_vals, Z)
                plt.xlabel('x')
                plt.title(f'Symbol Phase arg(p(x,ξ)) at ξ={xi0:.2f}')

            elif mode == 'Cotangent Fiber':
                pseudo_op.visualize_fiber(x_vals, np.linspace(*xi_range, density), x0=x0)

            elif mode == 'Characteristic Set':
                pseudo_op.visualize_characteristic_set(x_vals, np.linspace(*xi_range, density), x0=x0)

            elif mode == 'Characteristic Gradient':
                pseudo_op.visualize_characteristic_gradient(x_vals, np.linspace(*xi_range, density), x0=x0)

            elif mode == 'Hamiltonian Flow':
                pseudo_op.plot_hamiltonian_flow(x0=x0, xi0=xi0)

        # --- Dynamic container for sliders ---
        controls_box = VBox([mode_selector_1D, xi_slider, x_slider])

        # --- Function to adjust visible sliders based on mode ---
        def update_controls(change):
            mode = change['new']
            # modes that depend only on xi
            if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)',
                        'Group Velocity Field', 'Symplectic Vector Field']:
                controls_box.children = [mode_selector_1D, xi_slider]
            # modes that require xi and x
            elif mode in ['Hamiltonian Flow']:
                controls_box.children = [mode_selector_1D, xi_slider, x_slider]
            # modes that show no slider
            elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
                controls_box.children = [mode_selector_1D]

        mode_selector_1D.observe(update_controls, names='value')
        update_controls({'new': mode_selector_1D.value})

        # --- Interactive binding ---
        out = interactive_output(plot_1d, {'mode': mode_selector_1D, 'xi0': xi_slider, 'x0': x_slider})
        display(VBox([controls_box, out]))

    elif dim == 2:
        x, y = vars_x
        xi, eta = symbols('xi eta', real=True)
        symplectic_func = lambdify((x, y, xi, eta), [diff(expr, xi), diff(expr, eta)], 'numpy')
        symbol_func = lambdify((x, y, xi, eta), expr, 'numpy')

        xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
        eta_slider = FloatSlider(min=eta_range[0], max=eta_range[1], step=0.1, value=1.0, description='η₀')
        x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
        y_slider = FloatSlider(min=ylim[0], max=ylim[1], step=0.1, value=0.0, description='y₀')

        def plot_2d(mode, xi0, eta0, x0, y0):
            X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')

            def on_grid(values):
                # lambdify returns a plain scalar when the symbol does not depend on (x, y);
                # broadcast it to the grid shape expected by pcolormesh
                return np.broadcast_to(values, X.shape)

            if mode == 'Micro-Support (1/|p|)':
                Z = 1 / (np.abs(on_grid(symbol_func(X, Y, xi0, eta0))) + 1e-10)
                plt.pcolormesh(X, Y, Z, shading='auto', cmap='inferno')
                plt.colorbar(label='1/|p|')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.title(f'Micro-Support at ξ={xi0:.2f}, η={eta0:.2f}')

            elif mode == 'Symplectic Vector Field':
                U, V = symplectic_func(X, Y, xi0, eta0)
                U, V = on_grid(U), on_grid(V)
                plt.quiver(X, Y, U, V, scale=10, width=0.004)
                plt.xlabel('x')
                plt.ylabel('y')
                plt.title(f'Symplectic Field at ξ={xi0:.2f}, η={eta0:.2f}')

            elif mode == 'Symbol Amplitude':
                Z = np.abs(on_grid(symbol_func(X, Y, xi0, eta0)))
                plt.pcolormesh(X, Y, Z, shading='auto')
                plt.colorbar(label='|p(x,y,ξ,η)|')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.title(f'Symbol Amplitude at ξ={xi0:.2f}, η={eta0:.2f}')

            elif mode == 'Symbol Phase':
                Z = np.angle(on_grid(symbol_func(X, Y, xi0, eta0)))
                plt.pcolormesh(X, Y, Z, shading='auto', cmap='twilight')
                plt.colorbar(label='arg(p)')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.title(f'Symbol Phase at ξ={xi0:.2f}, η={eta0:.2f}')

            elif mode == 'Cotangent Fiber':
                pseudo_op.visualize_fiber(np.linspace(*xi_range, density), np.linspace(*eta_range, density),
                                          x0=x0, y0=y0)

            elif mode == 'Characteristic Set':
                pseudo_op.visualize_characteristic_set(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
                                                       y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)

            elif mode == 'Characteristic Gradient':
                pseudo_op.visualize_characteristic_gradient(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
                                                            y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)

            elif mode == 'Hamiltonian Flow':
                pseudo_op.plot_hamiltonian_flow(x0=x0, y0=y0, xi0=xi0, eta0=eta0)

        # --- Dynamic container for sliders ---
        controls_box = VBox([mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider])

        # --- Function to adjust visible sliders based on mode ---
        def update_controls(change):
            mode = change['new']
            # modes that depend only on xi and eta
            if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)', 'Symplectic Vector Field']:
                controls_box.children = [mode_selector_2D, xi_slider, eta_slider]
            # modes that require xi, eta, x and y
            elif mode in ['Hamiltonian Flow']:
                controls_box.children = [mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider]
            # modes that require x and y
            elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
                controls_box.children = [mode_selector_2D, x_slider, y_slider]

        mode_selector_2D.observe(update_controls, names='value')
        update_controls({'new': mode_selector_2D.value})

        # --- Interactive binding ---
        out = interactive_output(plot_2d, {'mode': mode_selector_2D, 'xi0': xi_slider, 'eta0': eta_slider, 'x0': x_slider, 'y0': y_slider})
        display(VBox([controls_box, out]))

    else:
        raise NotImplementedError("Only 1D and 2D symbol analysis is implemented.")

Launch an interactive dashboard for symbol exploration using ipywidgets.

This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol. It supports multiple visualization modes in both 1D and 2D, including group velocity fields, micro-support estimates, symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.

Parameters

  • pseudo_op (PseudoDifferentialOperator): the pseudo-differential operator whose symbol is to be analyzed interactively.
  • xlim, ylim (tuple of float): spatial domain limits along the x and y axes, respectively.
  • xi_range, eta_range (tuple): frequency domain limits along the ξ and η axes, respectively.
  • density (int): number of points per axis used to construct the evaluation grid; controls the resolution.
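In 1D, the evaluation grid described by xlim, xi_range and density can be read as a density × density mesh over the (x, ξ) phase plane. The following is a minimal NumPy sketch that assumes evenly spaced points. The helper name symbol_grid_1d is hypothetical and is not part of psipy.

```python
import numpy as np


def symbol_grid_1d(xlim, xi_range, density):
    """Hypothetical helper: evenly spaced (x, xi) mesh with `density` points per axis."""
    x = np.linspace(xlim[0], xlim[1], density)
    xi = np.linspace(xi_range[0], xi_range[1], density)
    # indexing="ij" makes axis 0 run over x and axis 1 over xi
    return np.meshgrid(x, xi, indexing="ij")


X, XI = symbol_grid_1d(xlim=(-2.0, 2.0), xi_range=(-5.0, 5.0), density=101)
print(X.shape, XI.shape)  # (101, 101) (101, 101)
```

In 2D, a full (x, y, ξ, η) grid would need density**4 evaluations. Most 2D modes instead fix two of the four variables with sliders (ξ₀, η₀ or x₀, y₀) and plot a two-dimensional slice.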

Notes

  • In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
  • In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
  • Visualization updates dynamically as parameters are adjusted via sliders or dropdown menus.
  • Supported visualization modes:
      • 'Symbol Amplitude': |p(x, ξ)| or |p(x, y, ξ, η)|
      • 'Symbol Phase': arg p(x, ξ), or its 2D analogue
      • 'Micro-Support (1/|p|)': reciprocal of the symbol magnitude
      • 'Cotangent Fiber': structure of the symbol over frequency space at fixed x
      • 'Characteristic Set': approximation of the zero set {p ≈ 0}
      • 'Characteristic Gradient': |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)|
      • 'Group Velocity Field': ∇_ξ p(x, ξ) or ∇_{ξ,η} p(x, y, ξ, η)
      • 'Symplectic Vector Field': (∇_ξ p, −∇_x p), or its 2D analogue
      • 'Hamiltonian Flow': trajectories generated by the Hamiltonian vector field
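Most of these modes are pointwise quantities derived from the symbol sampled on a grid. The sketch below evaluates them for the illustrative symbol p(x, ξ) = ξ² + x² − 1. This symbol is an assumption, chosen because its characteristic set is the unit circle. Derivatives are approximated by finite differences with numpy.gradient. The 1e-10 regularization and the 1e-2 tolerance for {p ≈ 0} are arbitrary choices, not values taken from psipy.

```python
import numpy as np


# Illustrative symbol (assumption): p(x, xi) = xi**2 + x**2 - 1
def p(x, xi):
    return xi**2 + x**2 - 1.0


x = np.linspace(-2.0, 2.0, 201)
xi = np.linspace(-2.0, 2.0, 201)
X, XI = np.meshgrid(x, xi, indexing="ij")   # axis 0 = x, axis 1 = xi
P = p(X, XI)

amplitude = np.abs(P)                       # 'Symbol Amplitude'
phase = np.angle(P)                         # 'Symbol Phase'
micro_support = 1.0 / (amplitude + 1e-10)   # 'Micro-Support (1/|p|)', regularized
char_set = amplitude < 1e-2                 # 'Characteristic Set' {p ≈ 0}, tolerance assumed

# Finite-difference partial derivatives along x (axis 0) and xi (axis 1)
dP_dx, dP_dxi = np.gradient(P, x, xi)
group_velocity = dP_dxi                     # 'Group Velocity Field': ∂p/∂ξ
symplectic_field = (dP_dxi, -dP_dx)         # 'Symplectic Vector Field': (∂p/∂ξ, -∂p/∂x)
char_gradient = np.hypot(dP_dx, dP_dxi)     # 'Characteristic Gradient': |∇p|
```

For a quadratic symbol, the interior central differences reproduce the exact derivatives 2x and 2ξ. The one-sided differences at the grid edges are only first-order accurate.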

Raises

NotImplementedError: if the spatial dimension is not 1D or 2D.

Displays

Interactive matplotlib figures that update dynamically based on widget inputs.
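The 'Hamiltonian Flow' mode traces trajectories of Hamilton's equations ẋ = ∂p/∂ξ, ξ̇ = −∂p/∂x generated by the symbol. Below is a minimal 1D sketch using scipy.integrate.solve_ivp. It assumes the illustrative symbol p(x, ξ) = (ξ² + x²)/2 (a harmonic oscillator, not a psipy default) with hand-coded derivatives. psipy's own plot_hamiltonian_flow may integrate differently, and its 2D version evolves the four variables (x, y, ξ, η).

```python
import numpy as np
from scipy.integrate import solve_ivp


# Illustrative Hamiltonian symbol (assumption): p(x, xi) = (xi**2 + x**2) / 2
def p(x, xi):
    return 0.5 * (xi**2 + x**2)


def hamiltonian_vector_field(t, z):
    x, xi = z
    dp_dxi = xi              # ∂p/∂ξ
    dp_dx = x                # ∂p/∂x
    return [dp_dxi, -dp_dx]  # ẋ = ∂p/∂ξ, ξ̇ = -∂p/∂x


x0, xi0 = 1.0, 0.0
sol = solve_ivp(hamiltonian_vector_field, (0.0, 2 * np.pi), [x0, xi0],
                rtol=1e-9, atol=1e-12)
x_t, xi_t = sol.y

# The exact flow preserves p, so its drift measures the integration error
energy_drift = np.max(np.abs(p(x_t, xi_t) - p(x0, xi0)))
print(f"max |p(t) - p(0)| = {energy_drift:.1e}")
```

Because the flow preserves p, the drift of p along the computed trajectory is a cheap accuracy check. For this symbol the orbit is a circle of period 2π, so the trajectory should also return to its starting point.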

class PDESolver:
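In the source below, compute_linear_operator obtains L(k) by applying each linear term to a plane wave exp(i(k·x − ωt)). This replaces ∂ₜ with −iω and ∂ₓ with ik. The sketch below shows the same idea in SymPy. It assumes the heat equation u_t = u_xx as an illustrative input and uses the convention L(k) = −iω(k) that the method applies to first-order-in-time equations. It is a sketch of the technique, not psipy's own code path.

```python
import numpy as np
import sympy as sp

t, x, kx, omega = sp.symbols("t x kx omega")

# Plane wave φ(x, t) = exp(i(kx·x - ω·t))
plane_wave = sp.exp(sp.I * (kx * x - omega * t))

# Illustrative linear part (assumption): heat equation u_t - u_xx = 0,
# applied directly to the plane wave (∂t -> -iω, ∂x -> i kx)
lhs = sp.diff(plane_wave, t) - sp.diff(plane_wave, x, 2)

# Dividing out φ leaves the characteristic equation kx**2 - i*omega = 0
char_eq = sp.simplify(lhs / plane_wave)
omega_sol = sp.solve(sp.Eq(char_eq, 0), omega)[0]   # omega = -i*kx**2

# First order in time: d(u_hat)/dt = L(k) u_hat with L(k) = -i*omega(k)
L_symbolic = sp.expand(-sp.I * omega_sol)           # -kx**2
L = sp.lambdify(kx, L_symbolic, "numpy")

k = np.fft.fftfreq(8, d=1.0 / 8)   # integer wavenumbers for a 2π-periodic grid of 8 points
print(L_symbolic)                  # -kx**2
print(np.exp(L(k) * 0.01))         # per-step factors exp(L(k)·dt) with dt = 0.01
```

The factors exp(L(k)·dt) computed at the end correspond to what setup_1D documents as self.exp_L. For the heat equation they damp each Fourier mode by exp(−k²·dt) per step.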
class PDESolver:
    """
    A partial differential equation (PDE) solver based on **spectral methods** using Fourier transforms.

    This solver supports symbolic specification of PDEs via SymPy and numerical solution using high-order spectral techniques.
    It is designed for both **linear and nonlinear time-dependent PDEs**, as well as **stationary pseudo-differential problems**.

    Key Features:
    -------------
    - Symbolic PDE parsing using SymPy expressions
    - 1D and 2D spatial domains with periodic boundary conditions
    - Fourier-based spectral discretization with dealiasing
    - Temporal integration schemes:
        - Default exponential time stepping
        - ETD-RK4 (Exponential Time Differencing Runge-Kutta of 4th order)
    - Nonlinear terms handled through pseudo-spectral evaluation
    - Built-in tools for:
        - Visualization of solutions and error surfaces
        - Symbol analysis of linear and pseudo-differential operators
        - Microlocal analysis (e.g., Hamiltonian flows)
        - CFL condition checking and numerical stability diagnostics

    Supported Operators:
    --------------------
    - Linear differential and pseudo-differential operators
    - Nonlinear terms up to second order in derivatives
    - Symbolic operator composition and adjoints
    - Asymptotic inversion of elliptic operators for stationary problems

    Example Usage:
    --------------
    >>> from PDESolver import *
    >>> u = Function('u')
    >>> t, x = symbols('t x')
    >>> eq = Eq(diff(u(t, x), t), diff(u(t, x), x, 2) + u(t, x)**2)
    >>> def initial(x): return np.sin(x)
    >>> solver = PDESolver(eq)
    >>> solver.setup(Lx=2*np.pi, Nx=128, Lt=1.0, Nt=1000, initial_condition=initial)
    >>> solver.solve()
    >>> ani = solver.animate()
    >>> HTML(ani.to_jshtml())  # Display animation in Jupyter notebook
    """
    def __init__(self, equation, time_scheme='default', dealiasing_ratio=2/3):
        """
        Initialize the PDE solver with a given equation.

        This method analyzes the input partial differential equation (PDE),
        identifies the unknown function and its dependencies, determines whether
        the problem is stationary or time-dependent, and prepares symbolic and
        numerical structures for solving in spectral space.

        Supported features:

        - 1D and 2D problems
        - Time-dependent and stationary equations
        - Linear and nonlinear terms
        - Pseudo-differential operators via `psiOp`
        - Source terms and boundary conditions

        The equation is parsed to extract linear, nonlinear, source, and
        pseudo-differential components. Symbolic manipulation is used to derive
        the Fourier representation of linear operators when applicable.

        Parameters
        ----------
        equation : sympy.Eq
            The PDE expressed as a SymPy equation.
        time_scheme : str
            Temporal integration scheme:
                - 'default' for exponential time stepping
                - 'ETD-RK4' for fourth-order exponential time differencing Runge–Kutta
        dealiasing_ratio : float
            Fraction of the maximum wavenumber kept by the dealiasing filter; modes
            above this cutoff are zeroed (2/3 gives the standard 2/3-rule truncation).

        Attributes initialized:

        - self.u: the unknown function (e.g., u(t, x))
        - self.dim: spatial dimension (1 or 2)
        - self.spatial_vars: list of spatial variables (e.g., [x] or [x, y])
        - self.is_stationary: boolean indicating if the problem is stationary
        - self.linear_terms: dictionary mapping derivatives (or u itself) to their coefficients
        - self.nonlinear_terms: list of nonlinear expressions
        - self.source_terms: list of source terms (independent of u)
        - self.pseudo_terms: list of (coefficient, symbol) pairs from psiOp
        - self.has_psi: boolean indicating presence of pseudo-differential operators
        - self.fft / self.ifft: appropriate FFT routines based on spatial dimension
        - self.kx, self.ky: symbolic wavenumber variables for Fourier space

        Raises:
            ValueError: If the equation does not contain exactly one unknown function,
                        if unsupported dimensions are detected, or if the dependencies are invalid.
        """
        self.time_scheme = time_scheme  # 'default' or 'ETD-RK4'
        self.dealiasing_ratio = dealiasing_ratio

        print("\n*********************************")
        print("* Partial differential equation *")
        print("*********************************\n")
        pprint(equation, num_columns=NUM_COLS)

        # Extract symbols and function from the equation
        functions = equation.atoms(Function)

        # Ignore the wrappers psiOp and Op
        excluded_wrappers = {'psiOp', 'Op'}

        # Extract the candidate functions (excluding wrappers)
        candidate_functions = [
            f for f in functions
            if f.func.__name__ not in excluded_wrappers
        ]

        # Among those, keep only user-defined functions (u(x), u(t, x), etc.)
        candidate_functions = [
            f for f in candidate_functions
            if isinstance(f, AppliedUndef)
        ]

        # Stationary detection: no dependence on t
        self.is_stationary = all(
            not any(str(arg) == 't' for arg in f.args)
            for f in candidate_functions
        )

        if len(candidate_functions) != 1:
            print("candidate_functions :", candidate_functions)
            raise ValueError("The equation must contain exactly one unknown function")

        self.u = candidate_functions[0]

        self.u_eq = self.u

        args = self.u.args

        if self.is_stationary:
            if len(args) not in (1, 2):
                raise ValueError("Stationary problems must depend on 1 or 2 spatial variables")
            self.spatial_vars = args
        else:
            if len(args) < 2 or len(args) > 3:
                raise ValueError("The function must depend on t and at least one spatial variable (x [, y])")
            self.t = args[0]
            self.spatial_vars = args[1:]

        self.dim = len(self.spatial_vars)
        if self.dim == 1:
            self.x = self.spatial_vars[0]
            self.y = None
        elif self.dim == 2:
            self.x, self.y = self.spatial_vars
        else:
            raise ValueError("Only 1D and 2D problems are supported.")

        if self.dim == 1:
            self.fft = partial(fft, workers=FFT_WORKERS)
            self.ifft = partial(ifft, workers=FFT_WORKERS)
        else:
            self.fft = partial(fft2, workers=FFT_WORKERS)
            self.ifft = partial(ifft2, workers=FFT_WORKERS)

        # Parse the equation
        self.linear_terms = {}
        self.nonlinear_terms = []
        self.symbol_terms = []
        self.source_terms = []
        self.pseudo_terms = []
        self.temporal_order = 0  # Order of the temporal derivative
        self.linear_terms, self.nonlinear_terms, self.symbol_terms, self.source_terms, self.pseudo_terms = self.parse_equation(equation)
        # Flag: pseudo-differential operator present?
        self.has_psi = bool(self.pseudo_terms)
        if self.has_psi:
            print('⚠️  Pseudo-differential operator detected: all other linear terms have been rejected.')
            self.is_spatial = False
            for coeff, expr in self.pseudo_terms:
                if expr.has(self.x) or (self.dim == 2 and expr.has(self.y)):
                    self.is_spatial = True
                    break

        if self.dim == 1:
            self.kx = symbols('kx')
        elif self.dim == 2:
            self.kx, self.ky = symbols('kx ky')

        # Compute linear operator
        if not self.is_stationary:
            self.compute_linear_operator()
        else:
            self.psi_ops = []
            for coeff, sym_expr in self.pseudo_terms:
                psi = PseudoDifferentialOperator(sym_expr, self.spatial_vars, self.u, mode='symbol')
                self.psi_ops.append((coeff, psi))

    def parse_equation(self, equation):
        """
        Parse the PDE to separate linear and nonlinear terms, symbolic operators (Op),
        source terms, and pseudo-differential operators (psiOp).

        This method rewrites the input equation in standard form (lhs - rhs = 0),
        expands it, and classifies each term into one of the following categories:

        - Linear terms involving derivatives or the unknown function u
        - Nonlinear terms (products with u, powers of u, etc.)
        - Symbolic pseudo-differential operators (Op)
        - Source terms (independent of u)
        - Pseudo-differential operators (psiOp)

        Parameters:
            equation (sympy.Eq): The partial differential equation to be analyzed.
                                 Can be provided as an Eq object or as a SymPy
                                 expression, which is then taken to equal zero.

        Returns:
            tuple: A 5-tuple containing:

                - linear_terms (dict): Mapping from derivative/function to coefficient.
                - nonlinear_terms (list): List of terms classified as nonlinear.
                - symbol_terms (list): List of (coefficient, symbolic operator) pairs.
                - source_terms (list): List of terms independent of the unknown function.
                - pseudo_terms (list): List of (coefficient, pseudo-differential symbol) pairs.

        Notes:
            - If `psiOp` is present in the equation, expansion is skipped for safety.
            - When `psiOp` is used, only nonlinear terms, source terms, and possibly
              a time derivative are allowed; other linear terms and symbolic operators
              (Op) are forbidden.
            - Classification logic includes:
                - Detection of nonlinear structures like products or powers of u
                - Mixed terms involving both u and its derivatives
                - External symbolic operators (Op) and pseudo-differential operators (psiOp)
        """
        def is_nonlinear_term(term, u_func):
            # If the term contains functions (Abs, sin, exp, ...) applied to u
            if term.has(u_func):
                for sub in preorder_traversal(term):
                    if isinstance(sub, Function) and sub.has(u_func) and sub.func != u_func.func:
                        return True
            # If the term contains a nonlinear power of u or of one of its derivatives
            if term.has(Pow):
                for pow_term in term.atoms(Pow):
                    if pow_term.base.has(u_func) and pow_term.exp != 1:
                        return True
            # If the term is a product of two or more factors depending on u
            # (e.g. u*u_x or u_x*u_xx)
            if term.func == Mul:
                u_factors = [f for f in term.args if f.has(u_func)]
                if len(u_factors) > 1:
                    return True
            return False

        print("\n********************")
        print("* Equation parsing *")
        print("********************\n")

        if isinstance(equation, Eq):
            lhs = equation.lhs - equation.rhs
        else:
            lhs = equation

        print(f"\nEquation rewritten in standard form: {lhs}")
        if lhs.has(psiOp):
            print("⚠️ psiOp detected: skipping expansion for safety")
            lhs_expanded = lhs
        else:
            lhs_expanded = expand(lhs)

        print(f"\nExpanded equation: {lhs_expanded}")

        linear_terms = {}
        nonlinear_terms = []
        symbol_terms = []
        source_terms = []
        pseudo_terms = []

        for term in lhs_expanded.as_ordered_terms():
            print(f"Analyzing term: {term}")

            if isinstance(term, psiOp):
                expr = term.args[0]
                pseudo_terms.append((1, expr))
                print("  --> Classified as pseudo linear term (psiOp)")
                continue

            # Otherwise, look for psiOp inside (general case)
            if term.has(psiOp):
                psiops = term.atoms(psiOp)
                for psi in psiops:
                    try:
                        coeff = simplify(term / psi)
                        expr = psi.args[0]
                        pseudo_terms.append((coeff, expr))
                        print("  --> Classified as pseudo linear term (psiOp)")
                    except Exception as e:
                        print(f"  ⚠️ Failed to extract psiOp coefficient in term: {term}")
                        print(f"     Reason: {e}")
                        nonlinear_terms.append(term)
                        print("  --> Fallback: classified as nonlinear")
                continue

            if term.has(Op):
                ops = term.atoms(Op)
                for op in ops:
                    coeff = term / op
                    expr = op.args[0]
                    symbol_terms.append((coeff, expr))
                    print("  --> Classified as symbolic linear term (Op)")
                continue

            if is_nonlinear_term(term, self.u):
                nonlinear_terms.append(term)
                print("  --> Classified as nonlinear")
                continue

            derivs = term.atoms(Derivative)
            if derivs:
                deriv = derivs.pop()
                coeff = term / deriv
                linear_terms[deriv] = linear_terms.get(deriv, 0) + coeff
                print(f"  Derivative found: {deriv}")
                print("  --> Classified as linear")
            elif self.u in term.atoms(Function):
                # Divide out u so that non-constant coefficients (e.g. x*u) are kept
                coeff = term / self.u
                linear_terms[self.u] = linear_terms.get(self.u, 0) + coeff
                print("  --> Classified as linear")
            else:
                source_terms.append(term)
                print("  --> Classified as source term")

        print(f"Final linear terms: {linear_terms}")
        print(f"Final nonlinear terms: {nonlinear_terms}")
        print(f"Symbol terms: {symbol_terms}")
        print(f"Pseudo terms: {pseudo_terms}")
        print(f"Source terms: {source_terms}")

        if pseudo_terms:
            # self.t only exists for time-dependent problems
            t_var = getattr(self, 't', None)
            # Linear terms other than a time derivative and the bare u term are forbidden
            invalid_linear_terms = {
                term: coeff for term, coeff in linear_terms.items()
                if not (
                    isinstance(term, Derivative)
                    and t_var in [v for v, _ in term.variable_count]
                )
                and term != self.u  # the bare u term (without derivative) is allowed
            }

            if invalid_linear_terms or symbol_terms:
                raise ValueError(
                    "When psiOp is used, only nonlinear terms, source terms, "
                    "and possibly a time derivative are allowed. "
                    "Other linear terms and Ops are forbidden."
                )

        return linear_terms, nonlinear_terms, symbol_terms, source_terms, pseudo_terms


    def compute_linear_operator(self):
        """
        Compute the symbolic Fourier representation L(k) of the linear operator
        derived from the linear part of the PDE.

        This method builds the dispersion relation by applying each symbolic derivative
        to a plane wave exp(i(k·x - ωt)) and dividing the result by the plane wave.
        It handles arbitrary derivative combinations and also incorporates symbolic
        operators (Op) and pseudo-differential operators (psiOp).

        Steps:
        ------
        1. Construct a plane wave φ(x, t) = exp(i(k·x - ωt)).
        2. Apply each term from self.linear_terms to φ (∂ₜ → -iω, ∂ₓ → ik).
        3. Divide by φ and simplify to obtain the characteristic equation in ω and k.
        4. Add symbolic operator terms (Op), if present.
        5. Detect the temporal order as the degree of the characteristic equation in ω.
        6. If psiOp terms are present, wrap them as PseudoDifferentialOperator objects;
           otherwise solve for ω and build the numerical function L(k) via lambdify.

        Sets:
        -----
        - self.L_symbolic : sympy.Expr
            Symbolic form of L(k) (only when no psiOp is present).
        - self.L : callable
            Numerical function of L(kx[, ky]) (only when no psiOp is present).
        - self.omega_symbolic, self.omega : sympy.Expr, callable
            Frequency root ω(k) and its numerical version (second-order equations only).
        - self.temporal_order : int
            Order of time derivatives detected.
        - self.psi_ops : list of (coeff, PseudoDifferentialOperator)
            Pseudo-differential terms present in the equation.

        Raises:
        -------
        ValueError if the dimension is unsupported, a derivative involves an unknown
        variable, or no solution for ω is found.
        """
        print("\n*******************************")
        print("* Linear operator computation *")
        print("*******************************\n")

        # --- Step 1: symbolic variables ---
        omega = symbols("omega")
        if self.dim == 1:
            kvars = [symbols("kx")]
            space_vars = [self.x]
        elif self.dim == 2:
            kvars = symbols("kx ky")
            space_vars = [self.x, self.y]
        else:
            raise ValueError("Only 1D and 2D are supported.")

        kdict = dict(zip(space_vars, kvars))
        self.k_symbols = kvars

        # Plane wave expression
        phase = sum(k * x for k, x in zip(kvars, space_vars)) - omega * self.t
        plane_wave = exp(I * phase)

        # --- Step 2: build lhs expression from linear terms ---
        lhs = 0
        for deriv, coeff in self.linear_terms.items():
            if isinstance(deriv, Derivative):
                total_factor = 1
                for var, n in deriv.variable_count:
                    if var == self.t:
                        total_factor *= (-I * omega)**n
                    elif var in kdict:
                        total_factor *= (I * kdict[var])**n
                    else:
                        raise ValueError(f"Unknown variable {var} in derivative")
                lhs += coeff * total_factor * plane_wave
            elif deriv == self.u:
                lhs += coeff * plane_wave
            else:
                raise ValueError(f"Unsupported linear term: {deriv}")

        # --- Step 3: dispersion relation ---
        equation = simplify(lhs / plane_wave)
        print("\nCharacteristic equation before symbol treatment:")
        pprint(equation, num_columns=NUM_COLS)

        print("\n--- Symbolic symbol analysis ---")
        symb_omega = 0
        symb_k = 0

        for coeff, symbol in self.symbol_terms:
            if symbol.has(omega):
                # Add omega-dependent terms directly
                symb_omega += coeff * symbol
            elif any(symbol.has(k) for k in self.k_symbols):
                symb_k += coeff * symbol.subs(dict(zip(symbol.free_symbols, self.k_symbols)))

        print(f"symb_omega: {symb_omega}")
        print(f"symb_k: {symb_k}")

        equation = equation + symb_omega + symb_k

        print("\nRaw characteristic equation:")
        pprint(equation, num_columns=NUM_COLS)

        # Temporal derivative order detection
        try:
            poly_eq = Eq(equation, 0)
            poly = poly_eq.lhs.as_poly(omega)
            self.temporal_order = poly.degree() if poly else 0
        except Exception as e:
            warnings.warn(f"Could not determine temporal order: {e}", RuntimeWarning)
            self.temporal_order = 0
        print(f"Temporal order from dispersion relation: {self.temporal_order}")
        print('self.pseudo_terms = ', self.pseudo_terms)
        if self.pseudo_terms:
            coeff_time = 1
            for term, coeff in self.linear_terms.items():
                if isinstance(term, Derivative) and any(var == self.t for var, _ in term.variable_count):
                    coeff_time = coeff
                    print(f"✅ Time derivative coefficient detected: {coeff_time}")
            self.psi_ops = []
            for coeff, sym_expr in self.pseudo_terms:
                # sym_expr is the SymPy differential expression; self.spatial_vars is [x] or [x, y]
                psi = PseudoDifferentialOperator(sym_expr / coeff_time, self.spatial_vars, self.u, mode='symbol')

                self.psi_ops.append((coeff, psi))
        else:
            dispersion = solve(Eq(equation, 0), omega)
            if not dispersion:
                raise ValueError("No solution found for omega")
            print("\n--- Solutions found ---")
            pprint(dispersion, num_columns=NUM_COLS)

            if self.temporal_order == 2:
                omega_expr = simplify(sqrt(dispersion[0]**2))
                self.omega_symbolic = omega_expr
                self.omega = lambdify(self.k_symbols, omega_expr, "numpy")
                self.L_symbolic = -omega_expr**2
            else:
                self.L_symbolic = -I * dispersion[0]

            self.L = lambdify(self.k_symbols, self.L_symbolic, "numpy")

            print("\n--- Final linear operator ---")
            pprint(self.L_symbolic, num_columns=NUM_COLS)

    def linear_rhs(self, u, is_v=False):
        """
        Apply the linear operator (in Fourier space) to the field u or v.

        Parameters
        ----------
        u : np.ndarray
            Input solution array.
        is_v : bool
            Whether to apply the operator to v instead of u.

        Returns
        -------
        np.ndarray
            Result of applying the linear operator.
        """
        if self.dim == 1:
            self.symbol_u = np.array(self.L(self.KX), dtype=np.complex128)
            self.symbol_v = self.symbol_u  # same operator for u and v
        elif self.dim == 2:
            self.symbol_u = np.array(self.L(self.KX, self.KY), dtype=np.complex128)
            self.symbol_v = self.symbol_u
        u_hat = self.fft(u)
        u_hat *= self.symbol_v if is_v else self.symbol_u
        u_hat *= self.dealiasing_mask
        return self.ifft(u_hat)

    def setup(self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic',
              initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
        """
        Configure the spatial/temporal grid and initialize the solution field.

        This method sets up the computational domain, initializes spatial and temporal grids,
        applies boundary conditions, and prepares symbolic and numerical operators.
        It also performs essential analyses such as:

            - CFL condition verification (for stability)
            - Symbol analysis (e.g., dispersion relation, regularity)
            - Wave propagation analysis for second-order equations

        If pseudo-differential operators (ψOp) are present, symbolic analysis is skipped
        in favor of interactive exploration via `interactive_symbol_analysis`.

        Parameters
        ----------
        Lx : float
            Size of the spatial domain along x-axis.
        Ly : float, optional
            Size of the spatial domain along y-axis (for 2D problems).
        Nx : int
            Number of spatial points along x-axis.
        Ny : int, optional
            Number of spatial points along y-axis (for 2D problems).
        Lt : float, default=1.0
            Total simulation time.
        Nt : int, default=100
            Number of time steps.
        boundary_condition : str, default='periodic'
            'periodic' or 'dirichlet'; Dirichlet conditions require an equation defined via psiOp.
        initial_condition : callable
            Function returning the initial state u(x, 0) or u(x, y, 0).
        initial_velocity : callable, optional
            Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0),
            required for second-order equations.
        n_frames : int, default=100
            Number of time frames to store during simulation for visualization or output.
        plot : bool, default=True
            If True, plot the symbol and, for second-order equations, analyze wave propagation.

        Raises
        ------
        ValueError
            If mandatory parameters are missing (e.g., Nx not given in 1D; Nx, Ly or Ny not
            given in 2D), or if Dirichlet conditions are requested without psiOp.

        Notes
        -----
        - The spatial discretization assumes periodic boundary conditions by default.
        - Fourier transforms are computed using complex-to-complex FFTs (`scipy.fft.fft`, `fft2`).
        - Frequency arrays (`KX`, `KY`) are defined following standard spectral conventions.
        - Dealiasing is applied using a sharp cutoff filter at a fraction of the maximum frequency.
        - For second-order equations, initial acceleration is derived from the governing operator.
        - Symbolic analysis includes plotting of the symbol's real/imaginary/absolute values
          and dispersion relation.

        See Also
        --------
        setup_1D : Sets up internal variables for one-dimensional problems.
        setup_2D : Sets up internal variables for two-dimensional problems.
        initialize_conditions : Applies initial data and enforces compatibility.
        check_cfl_condition : Verifies time step against stability constraints.
        plot_symbol : Visualizes the linear operator’s symbol in frequency space.
        analyze_wave_propagation : Analyzes group velocity.
        interactive_symbol_analysis : Interactive tools for ψOp-based equations.
        """

        # Temporal parameters
        self.Lt, self.Nt = Lt, Nt
        self.dt = Lt / Nt
        self.n_frames = n_frames
        self.frames = []
        self.initial_condition = initial_condition
        self.boundary_condition = boundary_condition
        self.plot = plot

        if self.boundary_condition == 'dirichlet' and not self.has_psi:
            raise ValueError(
                "Dirichlet boundary conditions require the equation to be defined via a pseudo-differential operator (psiOp). "
                "Please provide an equation involving psiOp for non-periodic boundary treatment."
            )

        # Dimension checks
        if self.dim == 1:
            if Nx is None:
                raise ValueError("Nx must be specified in 1D.")
            self.setup_1D(Lx, Nx)
        else:
            if None in (Nx, Ly, Ny):
                raise ValueError("In 2D, Nx, Ly and Ny must be provided.")
            self.setup_2D(Lx, Ly, Nx, Ny)

        # Initialization of solution and velocities
        if not self.is_stationary:
            self.initialize_conditions(initial_condition, initial_velocity)

        # Symbol analysis if present
        if self.has_psi:
            print("⚠️ For psiOp, use interactive_symbol_analysis.")
        else:
            if self.L_symbolic == 0:
                print("⚠️ Linear operator is null.")
            else:
                self.check_cfl_condition()
                self.check_symbol_conditions()
                if plot:
                    self.plot_symbol()
                    if self.temporal_order == 2:
                        self.analyze_wave_propagation()

    def setup_1D(self, Lx, Nx):
        """
        Configure internal variables for one-dimensional (1D) problems.

        This method initializes the spatial and frequency grids, builds the dealiasing
        mask, and prepares either the pseudo-differential symbol tables or the linear
        operator used in time evolution.

        The grid is periodic and uses the full complex FFT frequency convention
        (`fftfreq`, not `rfftfreq`). The spatial domain is centered at zero,
        [-Lx/2, Lx/2), with the right endpoint excluded.

        Parameters
        ----------
        Lx : float
            Physical size of the spatial domain along the x-axis.
        Nx : int
            Number of grid points in the x-direction.

        Attributes Set
        --------------
        - self.Lx : float
            Size of the spatial domain.
        - self.Nx : int
            Number of spatial points.
        - self.x_grid : np.ndarray
            1D array of spatial coordinates.
        - self.X : np.ndarray
            Alias to `self.x_grid`, used in physical space computations.
        - self.kx : np.ndarray
            Angular wavenumbers 2π · fftfreq(Nx, d=Lx/Nx), in standard FFT order.
        - self.KX : np.ndarray
            Alias to `self.kx`, used in frequency space computations.
        - self.dealiasing_mask : np.ndarray
            Boolean mask that removes high (aliasing-prone) frequencies in nonlinear
            terms and ψOp applications.
        - self.exp_L : np.ndarray
            Exponential of the linear operator scaled by the time step, exp(L(k) · dt).
            Set only when no ψOp is present.
        - self.omega_val : np.ndarray
            Frequency values ω(k) = Re[√(L(k))] used in second-order time stepping.
        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
            Cosine and sine of ω(k)·dt for dispersive propagation.
        - self.inv_omega : np.ndarray
            1/ω(k), set to zero where ω(k) = 0 to avoid division by zero.

        Notes
        -----
        - Frequencies are computed with `scipy.fft.fftfreq` and kept in standard FFT
          order (zero frequency first); no `fftshift` is applied.
        - Dealiasing uses a sharp cutoff: modes with |k| > dealiasing_ratio · max|k|
          are removed.
        - If pseudo-differential operators (ψOp) are present, symbol tables are precomputed via `prepare_symbol_tables`.
        - For second-order equations without ψOp, ω(k) is obtained from `self.omega` and
          the derived terms (`omega_val`, `cos_omega_dt`, `sin_omega_dt`, `inv_omega`)
          are set by `setup_omega_terms`.

        See Also
        --------
        setup_2D : Equivalent setup for two-dimensional problems.
        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
        setup_omega_terms : Sets up terms involving ω(k) for second-order evolution.
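
        Examples
        --------
        A minimal standalone sketch of the grid and mask construction above, using only
        NumPy and SciPy. The helper `spectral_grid_1d` is hypothetical (not part of psipy),
        and the default ratio 2/3 (the classic two-thirds rule) is an assumption, not
        necessarily the library default.

```python
import numpy as np
from scipy.fft import fftfreq


def spectral_grid_1d(Lx, Nx, dealiasing_ratio=2 / 3):
    # Hypothetical helper mirroring the grid logic of setup_1D.
    # Periodic grid on [-Lx/2, Lx/2): the right endpoint is excluded.
    x = np.linspace(-Lx / 2, Lx / 2, Nx, endpoint=False)
    # Angular wavenumbers in standard FFT order: 0, positive, then negative.
    kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
    # Sharp cutoff: keep modes with |k| <= ratio * max|k|.
    k_max = dealiasing_ratio * np.max(np.abs(kx))
    mask = np.abs(kx) <= k_max
    return x, kx, mask


x, kx, mask = spectral_grid_1d(2 * np.pi, 16)
```

        For Lx = 2π and Nx = 16 the wavenumbers are the integers 0 to 7 followed by
        -8 to -1 (up to rounding), and a 2/3 ratio keeps the 11 modes with |k| <= 5.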
        """
        self.Lx, self.Nx = Lx, Nx
        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
        self.X = self.x_grid
        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
        self.KX = self.kx

        # Dealiasing mask
        k_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
        self.dealiasing_mask = (np.abs(self.KX) <= k_max)

        # Preparation of symbol or linear operator
        if self.has_psi:
            self.prepare_symbol_tables()
        else:
            L_vals = np.array(self.L(self.KX), dtype=np.complex128)
            self.exp_L = np.exp(L_vals * self.dt)
            if self.temporal_order == 2:
                omega_val = self.omega(self.KX)
                self.setup_omega_terms(omega_val)

    def setup_2D(self, Lx, Ly, Nx, Ny):
        """
        Configure internal variables for two-dimensional (2D) problems.

        This method initializes the spatial and frequency grids, builds the dealiasing
        mask, and prepares either the pseudo-differential symbol tables or the linear
        operator used in time evolution.

        The grid is periodic and uses the full complex FFT frequency convention.
        The spatial domain is centered at zero, [-Lx/2, Lx/2) × [-Ly/2, Ly/2).
        All 2D arrays are built with `indexing='ij'`, so their shape is (Nx, Ny).

        Parameters
        ----------
        Lx : float
            Physical size of the spatial domain along the x-axis.
        Ly : float
            Physical size of the spatial domain along the y-axis.
        Nx : int
            Number of grid points along the x-direction.
        Ny : int
            Number of grid points along the y-direction.

        Attributes Set
        --------------
        - self.Lx, self.Ly : float
            Size of the spatial domain in each direction.
        - self.Nx, self.Ny : int
            Number of spatial points in each direction.
        - self.x_grid, self.y_grid : np.ndarray
            1D arrays of spatial coordinates in the x and y directions.
        - self.X, self.Y : np.ndarray
            2D meshgrids of spatial coordinates for physical space computations.
        - self.kx, self.ky : np.ndarray
            Angular wavenumbers along x and y, in standard FFT order.
        - self.KX, self.KY : np.ndarray
            Meshgrids of wavenumbers used in frequency space computations.
        - self.dealiasing_mask : np.ndarray
            Boolean mask that removes high (aliasing-prone) frequencies; kx and ky are
            truncated independently, giving a rectangular cutoff.
        - self.exp_L : np.ndarray
            Exponential of the linear operator scaled by the time step, exp(L(kx, ky) · dt).
            Set only when no ψOp is present.
        - self.omega_val : np.ndarray
            Frequency values ω(kx, ky) = Re[√(L(kx, ky))] used in second-order time stepping.
        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
            Cosine and sine of ω(kx, ky)·dt for dispersive propagation.
        - self.inv_omega : np.ndarray
            1/ω(kx, ky), set to zero where ω(kx, ky) = 0 to avoid division by zero.

        Notes
        -----
        - Frequencies are computed with `scipy.fft.fftfreq` and kept in standard FFT
          order (zero frequency first); no `fftshift` is applied.
        - Dealiasing uses a sharp cutoff applied separately along kx and ky, based on
          `self.dealiasing_ratio`.
        - If pseudo-differential operators (ψOp) are present, symbol tables are precomputed via `prepare_symbol_tables`.
        - For second-order equations without ψOp, ω(kx, ky) is obtained from `self.omega`
          and the derived terms are set by `setup_omega_terms`.

        See Also
        --------
        setup_1D : Equivalent setup for one-dimensional problems.
        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
        setup_omega_terms : Sets up terms involving ω(kx, ky) for second-order evolution.
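
        Examples
        --------
        A minimal sketch of the 2D construction with `indexing='ij'`, showing that axis 0
        follows x, axis 1 follows y, and that the mask combines two independent 1D
        cutoffs. The sizes and the 2/3 ratio are illustrative assumptions.

```python
import numpy as np
from scipy.fft import fftfreq

# Illustrative sizes (assumptions): Nx != Ny makes the axis order visible.
Lx, Ly, Nx, Ny = 2 * np.pi, 2 * np.pi, 8, 10
ratio = 2 / 3  # assumed dealiasing ratio

x = np.linspace(-Lx / 2, Lx / 2, Nx, endpoint=False)
y = np.linspace(-Ly / 2, Ly / 2, Ny, endpoint=False)
# indexing='ij': axis 0 follows x, axis 1 follows y, so arrays have shape (Nx, Ny).
X, Y = np.meshgrid(x, y, indexing="ij")
kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
ky = 2 * np.pi * fftfreq(Ny, d=Ly / Ny)
KX, KY = np.meshgrid(kx, ky, indexing="ij")

# Rectangular cutoff: each direction is truncated independently.
mask = (np.abs(KX) <= ratio * np.max(np.abs(kx))) & (np.abs(KY) <= ratio * np.max(np.abs(ky)))
```

        Here 5 of the 8 kx modes and 7 of the 10 ky modes survive, so the mask keeps
        35 of the 80 Fourier coefficients.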
        """
        self.Lx, self.Ly = Lx, Ly
        self.Nx, self.Ny = Nx, Ny
        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
        self.y_grid = np.linspace(-Ly/2, Ly/2, Ny, endpoint=False)
        self.X, self.Y = np.meshgrid(self.x_grid, self.y_grid, indexing='ij')
        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
        self.ky = 2 * np.pi * fftfreq(Ny, d=Ly / Ny)
        self.KX, self.KY = np.meshgrid(self.kx, self.ky, indexing='ij')

        # Dealiasing mask (rectangular: each direction is truncated independently)
        kx_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
        ky_max = self.dealiasing_ratio * np.max(np.abs(self.ky))
        self.dealiasing_mask = (np.abs(self.KX) <= kx_max) & (np.abs(self.KY) <= ky_max)

        # Preparation of symbol or linear operator
        if self.has_psi:
            self.prepare_symbol_tables()
        else:
            # Cast to complex, as in the 1D case
            L_vals = np.array(self.L(self.KX, self.KY), dtype=np.complex128)
            self.exp_L = np.exp(L_vals * self.dt)
            if self.temporal_order == 2:
                omega_val = self.omega(self.KX, self.KY)
                self.setup_omega_terms(omega_val)

    def setup_omega_terms(self, omega_val):
        """
        Initialize terms derived from the angular frequency ω for time evolution.

        This method precomputes and stores the trigonometric and inverse-frequency arrays
        derived from the dispersion relation ω(k), which second-order time integration
        schemes use for wave-like equations with dispersive behavior:
            cos(ω·dt), sin(ω·dt), 1/ω

        The inverse frequency is set to zero wherever ω = 0 to avoid division by zero.

        Parameters
        ----------
        omega_val : np.ndarray
            Array of angular frequency values ω(k) evaluated at the discrete wavenumbers.
            One-dimensional (1D) or two-dimensional (2D), depending on the spatial dimension.

        Attributes Set
        --------------
        - self.omega_val : np.ndarray
            The input angular frequency array (stored by reference, not copied).
        - self.cos_omega_dt : np.ndarray
            Cosine of ω(k) multiplied by the time step: cos(ω(k) · dt).
        - self.sin_omega_dt : np.ndarray
            Sine of ω(k) multiplied by the time step: sin(ω(k) · dt).
        - self.inv_omega : np.ndarray
            Inverse of ω(k), with zeros where ω(k) == 0 to avoid division by zero.

        Notes
        -----
        - Called by `setup_1D` and `setup_2D` for second-order-in-time equations without
          ψOp (e.g., wave or Klein-Gordon equations).
        - Setting 1/ω to zero where ω = 0 (for instance at k = 0) keeps the arrays finite;
          any term multiplied by `inv_omega` vanishes on those modes.
        - These precomputed arrays are used in spectral propagators for time stepping.

        See Also
        --------
        setup_1D : Sets up internal variables for one-dimensional problems.
        setup_2D : Sets up internal variables for two-dimensional problems.
        solve : Time integration using the computed frequency terms.
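
        Examples
        --------
        A minimal sketch of how these arrays can advance one Fourier mode of
        ∂ₜ²û = −ω²û exactly over one step. The helper `omega_terms` is hypothetical, and
        the update formula is an assumed use of the arrays; the propagator actually used
        by the time integration (see `solve`) is not shown here.

```python
import numpy as np


def omega_terms(omega_val, dt):
    # Hypothetical helper reproducing the precomputation in setup_omega_terms.
    cos_w = np.cos(omega_val * dt)
    sin_w = np.sin(omega_val * dt)
    inv_w = np.zeros_like(omega_val, dtype=float)
    nonzero = omega_val != 0
    inv_w[nonzero] = 1.0 / omega_val[nonzero]
    return cos_w, sin_w, inv_w


omega = np.array([0.0, 1.0, 2.0])
dt = 0.1
u0 = np.ones(3)
v0 = np.full(3, 0.5)
cos_w, sin_w, inv_w = omega_terms(omega, dt)
# Exact one-step update for modes with omega != 0 (assumed propagator form).
u1 = cos_w * u0 + sin_w * inv_w * v0
```

        For ω ≠ 0 this matches u₀·cos(ωΔt) + v₀·sin(ωΔt)/ω. On modes where ω = 0 the
        zeroed inverse drops the v₀·Δt term, so a propagator built this way leaves u₀
        unchanged there.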
        """
        self.omega_val = omega_val
        self.cos_omega_dt = np.cos(omega_val * self.dt)
        self.sin_omega_dt = np.sin(omega_val * self.dt)
        self.inv_omega = np.zeros_like(omega_val)
        nonzero = omega_val != 0
        self.inv_omega[nonzero] = 1.0 / omega_val[nonzero]

    def evaluate_source_at_t0(self):
        """
        Evaluate the source terms at the initial time t = 0 over the spatial grid.

        This method sums all source terms at t = 0 and evaluates the result at every
        point of the spatial domain. It supports both one-dimensional (1D) and
        two-dimensional (2D) configurations.

        Returns
        -------
        np.ndarray
            Float64 array of the total source term at t = 0:
            - In 1D: shape (Nx,), evaluated at each x in `self.x_grid`.
            - In 2D: shape (Nx, Ny), evaluated at each (x, y) pair of the grid.

        Notes
        -----
        - Each entry of `self.source_terms` must be a SymPy expression; it is evaluated
          with `subs` followed by `evalf()`.
        - In 1D, each term is evaluated at (t=0, x=x_val).
        - In 2D, each term is evaluated at (t=0, x=x_val, y=y_val).
        - The result is cast to float64, so the source must be real-valued at t = 0.
        - Substitution runs in a Python loop over every grid point, so the cost grows
          with the number of grid points.

        See Also
        --------
        setup : Initializes the spatial grid and source terms.
        initialize_conditions : Uses this evaluation to compute the initial acceleration.
        solve : Uses this evaluation during the first time step.
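
        Examples
        --------
        A minimal sketch comparing the point-by-point substitution used here with a
        vectorized alternative based on `sympy.lambdify`. The source terms are made-up
        examples, and the lambdify variant is a suggested alternative, not what this
        method currently does.

```python
import numpy as np
import sympy as sp

t, x = sp.symbols("t x", real=True)
# Made-up source terms: f(t, x) = sin(x) * exp(-t) + 1/2.
source_terms = [sp.sin(x) * sp.exp(-t), sp.Rational(1, 2)]
x_grid = np.linspace(-np.pi, np.pi, 8, endpoint=False)

# Point-by-point substitution, as in evaluate_source_at_t0.
f_subs = np.array([
    sum(term.subs(t, 0).subs(x, x_val).evalf() for term in source_terms)
    for x_val in x_grid
], dtype=np.float64)

# Vectorized alternative: lambdify the total source once, evaluate on the whole grid.
f_total = sp.lambdify((t, x), sum(source_terms), "numpy")
# broadcast_to covers sources that do not depend on x (lambdify then returns a scalar).
f_vec = np.broadcast_to(f_total(0.0, x_grid), x_grid.shape).astype(np.float64)
```

        Both arrays agree; the lambdified version avoids one SymPy substitution per grid
        point.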
        """
        if self.dim == 1:
            # Evaluation on the 1D spatial grid
            return np.array([
                sum(term.subs(self.t, 0).subs(self.x, x_val).evalf()
                    for term in self.source_terms)
                for x_val in self.x_grid
            ], dtype=np.float64)
        else:
            # Evaluation on the 2D spatial grid, shape (Nx, Ny)
            return np.array([
                [sum(term.subs({self.t: 0, self.x: x_val, self.y: y_val}).evalf()
                      for term in self.source_terms)
                 for y_val in self.y_grid]
                for x_val in self.x_grid
            ], dtype=np.float64)

    def initialize_conditions(self, initial_condition, initial_velocity):
        """
        Initialize the solution and velocity fields at t = 0.

        This method sets up the initial state of the solution `u_prev` and, for
        second-order evolution equations, the time derivative (velocity) `v_prev`.

        For second-order equations, it also computes the backward-in-time value `u_prev2`
        needed by the Leap-Frog method. The acceleration at t = 0 is computed from:
            ∂ₜ²u = L(u) + N(u) + f(x, t=0)
        where L is the linear operator (−Op(p) when a ψOp is present), N is the
        nonlinear term, and f is the source term.

        Parameters
        ----------
        initial_condition : callable
            Function returning the initial condition u(x, 0) or u(x, y, 0).
        initial_velocity : callable or None
            Function returning the initial velocity ∂ₜu(x, 0) or ∂ₜu(x, y, 0). Required for
            second-order equations; ignored otherwise.

        Raises
        ------
        ValueError
            If `initial_velocity` is not provided for second-order equations.

        Notes
        -----
        - Applies the configured boundary conditions (`apply_boundary`) to the initial data.
        - Stores a copy of the initial state in `self.frames` for visualization/output.
        - For second-order systems, copies of the initial data are kept in `self.u0`
          and `self.v0`.
        - In second-order systems, initializes `self.u_prev2` with a second-order Taylor
          expansion backward in time, an approximation of u(−dt) with error O(dt³):
          u_prev2 = u_prev - dt * v_prev + 0.5 * dt² * (∂ₜ²u)
        - `u_prev2` is computed only if it does not already exist.

        See Also
        --------
        apply_boundary : Enforces the configured boundary conditions on the solution field.
        apply_psiOp : Computes the pseudo-differential operator action for the acceleration.
        linear_rhs : Evaluates the linear part of the equation in Fourier space.
        apply_nonlinear : Handles nonlinear terms with spectral differentiation.
        evaluate_source_at_t0 : Evaluates source terms at the initial time.
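
        Examples
        --------
        A minimal single-mode check of the backward Taylor seed, assuming a linear wave
        equation ∂ₜ²u = ∂ₓ²u with u(x, 0) = cos(kx) and ∂ₜu(x, 0) = 0. The acceleration
        is then −k²u and the exact value at t = −Δt is cos(kΔt)·cos(kx). The values of
        k and Δt are arbitrary choices.

```python
import numpy as np

# Single Fourier mode of u_tt = u_xx (assumed example): u(x, 0) = cos(kx), u_t(x, 0) = 0.
k, dt = 2.0, 1e-2
x = np.linspace(0.0, 2 * np.pi, 64, endpoint=False)
u_prev = np.cos(k * x)
v_prev = np.zeros_like(x)
acc0 = -k**2 * u_prev  # acceleration at t = 0 (no nonlinear or source terms)

# Backward Taylor seed used for the Leap-Frog scheme.
u_prev2 = u_prev - dt * v_prev + 0.5 * dt**2 * acc0

# Exact value at t = -dt for comparison.
u_exact = np.cos(k * dt) * u_prev
err = np.max(np.abs(u_prev2 - u_exact))
```

        Because the initial velocity is zero, the third-order Taylor term vanishes and
        the error reduces to (kΔt)⁴/24, about 6.7e-9 here.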
        """
        # Initial condition
        if self.dim == 1:
            self.u_prev = initial_condition(self.X)
        else:
            self.u_prev = initial_condition(self.X, self.Y)
        self.apply_boundary(self.u_prev)

        # Initial velocity (second order)
        if self.temporal_order == 2:
            if initial_velocity is None:
                raise ValueError("Initial velocity is required for second-order equations.")
            if self.dim == 1:
                self.v_prev = initial_velocity(self.X)
            else:
                self.v_prev = initial_velocity(self.X, self.Y)
            self.u0 = np.copy(self.u_prev)
            self.v0 = np.copy(self.v_prev)

            # Calculation of u_prev2 from the initial acceleration
            if not hasattr(self, 'u_prev2'):
                if self.has_psi:
                    acc0 = -self.apply_psiOp(self.u_prev)
                else:
                    acc0 = self.linear_rhs(self.u_prev, is_v=False)
                # apply_nonlinear returns N(u)·dt; divide by dt to recover N(u)
                rhs_nl = self.apply_nonlinear(self.u_prev, is_v=False) / self.dt
                acc0 += rhs_nl
                if hasattr(self, 'source_terms') and self.source_terms:
                    acc0 += self.evaluate_source_at_t0()
                self.u_prev2 = self.u_prev - self.dt * self.v_prev + 0.5 * self.dt**2 * acc0

        self.frames = [self.u_prev.copy()]

    def apply_boundary(self, u):
        """
        Apply boundary conditions in place to the solution array, according to
        `self.boundary_condition`.

        Two types of boundary conditions are supported:

        - 'periodic': copies interior values onto the opposite boundary.
        - 'dirichlet': sets all boundary values to zero (homogeneous Dirichlet condition).

        Parameters
        ----------
        u : np.ndarray
            The solution array representing the field values on a spatial grid.
            In 1D, shape must be (Nx,). In 2D, shape must be (Nx, Ny).
            The array is modified in place; nothing is returned.

        Raises
        ------
        ValueError
            If `self.boundary_condition` is not one of {'periodic', 'dirichlet'}.

        Notes
        -----
        - For 'periodic':
            * In 1D: u[0] = u[-2], u[-1] = u[1].
            * In 2D: the same copies are applied along both axes, i.e. the first and last
              rows (and columns) receive the second-to-last and second rows (and columns).
        - For 'dirichlet':
            * All boundary points are explicitly set to zero.
        """

        if self.boundary_condition == 'periodic':
            if self.dim == 1:
                u[0] = u[-2]
                u[-1] = u[1]
            elif self.dim == 2:
                u[0, :] = u[-2, :]
                u[-1, :] = u[1, :]
                u[:, 0] = u[:, -2]
                u[:, -1] = u[:, 1]

        elif self.boundary_condition == 'dirichlet':
            if self.dim == 1:
                u[0] = 0
                u[-1] = 0
            elif self.dim == 2:
                u[0, :] = 0
                u[-1, :] = 0
                u[:, 0] = 0
                u[:, -1] = 0

        else:
            raise ValueError(
                f"Invalid boundary condition '{self.boundary_condition}'. "
                "Supported types are 'periodic' and 'dirichlet'."
            )

    def apply_nonlinear(self, u, is_v=False):
        """
        Evaluate the nonlinear terms of the PDE using spectral differentiation with dealiasing.

        The field is first dealiased in Fourier space. First-order spatial derivatives are
        then computed spectrally (multiplication by i·k followed by an inverse FFT) and
        substituted into each symbolic nonlinear term, which is evaluated numerically with
        `lambdify`. Removing the highest frequencies with the dealiasing mask limits the
        aliasing errors produced by products of fields.

        Parameters
        ----------
        u : numpy.ndarray
            Current solution array on the spatial grid.
        is_v : bool, optional
            If True, `self.v_prev` is substituted for the unknown in the nonlinear terms,
            while the derivatives are still computed from `u`. Default is False.

        Returns
        -------
        numpy.ndarray
            Complex array containing the sum of the nonlinear terms multiplied by dt,
            N(u)·dt. Zeros are returned when there are no nonlinear terms.

        Notes
        -----
        - In 1D, ∂ₓu is computed via FFT; in 2D, both ∂ₓu and ∂ᵧu are computed.
        - Uses lambdify to evaluate symbolic nonlinear expressions numerically, with t = 0.
        - Every `Derivative` whose first differentiation variable is x (or y) is replaced
          by the symbol 'u_x' (or 'u_y') before evaluation, so only first-order derivatives
          are represented correctly.
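
        Examples
        --------
        A minimal standalone sketch of the same procedure for the Burgers-type term
        u·∂ₓu: dealias the field, compute ∂ₓu spectrally, replace the derivative by a
        plain symbol, and evaluate with `lambdify`. It calls `scipy.fft` directly instead
        of the class's `fft`/`ifft` wrappers; the symbol names and the 2/3 ratio are
        illustrative assumptions.

```python
import numpy as np
import sympy as sp
from scipy.fft import fft, ifft, fftfreq

t, x = sp.symbols("t x", real=True)
u = sp.Function("u")(t, x)
term = u * sp.Derivative(u, x)  # example nonlinear term u * u_x

# Replace the first-order x-derivative first, then the unknown itself, by plain symbols.
u_s, u_x_s = sp.symbols("u_s u_x")
expr = term.subs(sp.Derivative(u, x), u_x_s).subs(u, u_s)
term_func = sp.lambdify((t, x, u_s, u_x_s), expr, "numpy")

# Periodic grid and a 2/3 dealiasing mask (ratio assumed).
L, n = 2 * np.pi, 64
xg = np.linspace(-L / 2, L / 2, n, endpoint=False)
k = 2 * np.pi * fftfreq(n, d=L / n)
mask = np.abs(k) <= (2 / 3) * np.max(np.abs(k))

u_hat = fft(np.sin(xg)) * mask      # dealiased spectrum of u = sin(x)
u_f = ifft(u_hat).real              # dealiased field
u_x = ifft(1j * k * u_hat).real     # spectral derivative du/dx
nl = term_func(0.0, xg, u_f, u_x)   # evaluates u * u_x on the grid
```

        For u = sin(x) the result matches sin(x)·cos(x) to rounding error, since the
        retained modes ±1 are differentiated exactly.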
        """
        if not self.nonlinear_terms:
            return np.zeros_like(u, dtype=np.complex128)

        nonlinear_term = np.zeros_like(u, dtype=np.complex128)

        if self.dim == 1:
            u_hat = self.fft(u)
            u_hat *= self.dealiasing_mask
            u = self.ifft(u_hat)

            u_x_hat = (1j * self.KX) * u_hat
            u_x = self.ifft(u_x_hat)

            for term in self.nonlinear_terms:
                term_replaced = term
                if term.has(Derivative):
                    for deriv in term.atoms(Derivative):
                        if deriv.args[1][0] == self.x:
                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))
                term_func = lambdify((self.t, self.x, self.u_eq, 'u_x'), term_replaced, 'numpy')
                if is_v:
                    nonlinear_term += term_func(0, self.X, self.v_prev, u_x)
                else:
                    nonlinear_term += term_func(0, self.X, u, u_x)

        elif self.dim == 2:
            u_hat = self.fft(u)
            u_hat *= self.dealiasing_mask
            u = self.ifft(u_hat)

            u_x_hat = (1j * self.KX) * u_hat
            u_y_hat = (1j * self.KY) * u_hat
            u_x = self.ifft(u_x_hat)
            u_y = self.ifft(u_y_hat)

            for term in self.nonlinear_terms:
                term_replaced = term
                if term.has(Derivative):
                    for deriv in term.atoms(Derivative):
                        if deriv.args[1][0] == self.x:
                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))
                        elif deriv.args[1][0] == self.y:
                            term_replaced = term_replaced.subs(deriv, symbols('u_y'))
                term_func = lambdify((self.t, self.x, self.y, self.u_eq, 'u_x', 'u_y'), term_replaced, 'numpy')
                if is_v:
                    nonlinear_term += term_func(0, self.X, self.Y, self.v_prev, u_x, u_y)
                else:
                    nonlinear_term += term_func(0, self.X, self.Y, u, u_x, u_y)
        else:
            raise ValueError("Unsupported spatial dimension.")

        return nonlinear_term * self.dt

    def prepare_symbol_tables(self):
        """
        Precompute and store evaluated pseudo-differential operator symbols for spectral methods.

        This method evaluates every pseudo-differential operator (ψOp) of the PDE on the
        spatial and frequency grids, pairs each evaluated symbol with its coefficient, and
        combines them into a single composite symbol used in time stepping and inversion.

        The evaluation uses the `evaluate` method of each PseudoDifferentialOperator, which
        computes p(x, ξ) or p(x, y, ξ, η) numerically on the current grid. Each value is
        then converted to a Python complex number with SymPy's `N`.

        Attributes Set
        --------------
        - self.precomputed_symbols : list of (coeff, symbol_array)
            Each tuple contains a coefficient and its evaluated symbol on the grid.
        - self.combined_symbol : np.ndarray
            Sum of all scaled symbol arrays, ∑ₖ coeffₖ · pₖ, as a complex128 array.

        Raises
        ------
        ValueError
            If the spatial dimension is not 1D or 2D.
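
        Examples
        --------
        A minimal sketch of the same coefficient-weighted combination, with `lambdify`
        standing in for `PseudoDifferentialOperator.evaluate` and a vectorized complex
        cast in place of the element-wise `complex(N(val))` loop. The (coefficient,
        symbol) pairs are hypothetical, and both symbols depend on ξ only.

```python
import numpy as np
import sympy as sp

x, xi = sp.symbols("x xi", real=True)
# Hypothetical (coefficient, symbol) pairs standing in for self.psi_ops.
psi_terms = [(2.0, xi**2), (1j, sp.sqrt(1 + xi**2))]

X = np.linspace(-1.0, 1.0, 5)
KX = np.array([0.0, 1.0, 2.0, -2.0, -1.0])

precomputed_symbols = []
for coeff, p in psi_terms:
    p_func = sp.lambdify((x, xi), p, "numpy")
    # Vectorized evaluation and cast to complex128 in one step.
    values = np.asarray(p_func(X, KX), dtype=np.complex128)
    precomputed_symbols.append((coeff, values))

# Combined symbol: sum_k coeff_k * p_k on the grid.
combined_symbol = sum(c * s for c, s in precomputed_symbols)
```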
        """
        self.precomputed_symbols = []
        self.combined_symbol = 0
        for coeff, psi in self.psi_ops:
            if self.dim == 1:
                raw = psi.evaluate(self.X, None, self.KX, None)
            elif self.dim == 2:
                raw = psi.evaluate(self.X, self.Y, self.KX, self.KY)
            else:
                raise ValueError('Unsupported spatial dimension.')
            raw_flat = raw.flatten()
            converted = np.array([complex(N(val)) for val in raw_flat], dtype=np.complex128)
            raw_eval = converted.reshape(raw.shape)
            self.precomputed_symbols.append((coeff, raw_eval))
        self.combined_symbol = sum((coeff * sym for coeff, sym in self.precomputed_symbols))
        self.combined_symbol = np.array(self.combined_symbol, dtype=np.complex128)

    def total_symbol_expr(self):
        """
        Compute the total pseudo-differential symbol expression from all pseudo_terms.

        This method constructs the full symbol of the pseudo-differential operator
        by summing all coefficient-weighted symbolic expressions.

        The result is cached in `self.symbol_expr` to avoid recomputation.

        Returns
        -------
        sympy.Expr
            The combined symbol expression, representing the full
            pseudo-differential operator in symbolic form.

        Examples
        --------
        Given pseudo_terms = [(2, ξ²), (1, x·ξ)], this returns 2·ξ² + x·ξ.
        """
        if not hasattr(self, 'symbol_expr'):
            self.symbol_expr = sum(coeff * expr for coeff, expr in self.pseudo_terms)
        return self.symbol_expr

    def build_symbol_func(self, expr):
        """
        Build a numerical evaluation function from a symbolic pseudo-differential operator expression.

        This method converts a symbolic expression representing the symbol of a
        pseudo-differential operator into a NumPy-compatible callable. The arguments of
        the callable depend on the dimensionality of the problem.

        Parameters
        ----------
        expr : sympy.Expr
            SymPy expression for the symbol. It may depend on the spatial variables (x, y)
            and the frequency variables (xi, eta).

        Returns
        -------
        callable
            A lambdified function that takes:

            - In 1D: `(x, xi)`, the spatial coordinate and frequency.
            - In 2D: `(x, y, xi, eta)`, the spatial coordinates and frequencies.

            It returns a NumPy array of symbol values over the input grids.

        Notes
        -----
        - Uses SymPy's `lambdify` with the `'numpy'` backend for vectorized evaluation.
        - The argument symbols are created with `real=True` and named x, xi (and y, eta);
          `expr` should be written in terms of these variable names.
        - Useful wherever a vectorized evaluation of a symbol is needed, e.g. in visualization tools.
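
        Examples
        --------
        A minimal sketch of evaluating a lambdified 1D symbol on a full (x, ξ) table. The
        symbol p(x, ξ) = (1 + x²)ξ² and the grids are illustrative. A constant symbol
        lambdifies to a function returning a scalar, so it is broadcast explicitly here.

```python
import numpy as np
import sympy as sp

# Real symbols named x and xi, as in the 1D branch of build_symbol_func.
x, xi = sp.symbols("x xi", real=True)
p = (1 + x**2) * xi**2  # illustrative symbol p(x, xi)
p_func = sp.lambdify((x, xi), p, "numpy")

x_grid = np.linspace(-1.0, 1.0, 5)
xi_grid = np.array([0.0, 1.0, 2.0])
# indexing='ij' gives a table with shape (len(x_grid), len(xi_grid)).
Xg, XIg = np.meshgrid(x_grid, xi_grid, indexing="ij")
table = p_func(Xg, XIg)

# A constant symbol lambdifies to a function returning a scalar; broadcast it explicitly.
c_func = sp.lambdify((x, xi), sp.Integer(3), "numpy")
c_table = np.broadcast_to(c_func(Xg, XIg), Xg.shape)
```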
        """
        if self.dim == 1:
            x, xi = symbols('x xi', real=True)
            return lambdify((x, xi), expr, 'numpy')
        else:
            x, y, xi, eta = symbols('x y xi eta', real=True)
            return lambdify((x, y, xi, eta), expr, 'numpy')

    def apply_psiOp(self, u):
        """
        Apply the pseudo-differential operator to the input field u.

        The result is the coefficient-weighted sum ∑ₖ cₖ · Op(pₖ)u over all entries of
        `self.psi_ops`. Each term is computed by the `apply()` method of the corresponding
        PseudoDifferentialOperator instance, which chooses the algorithm based on:

        - whether the symbol depends on the spatial variables (x/y);
        - the boundary condition in use (periodic or Dirichlet).

        Dispatch logic:

        - Symbol independent of x: Fourier multiplication,
          Op(p)(D) u = 𝓕⁻¹[ p(ξ) · 𝓕(u) ].
        - Spatially varying symbol, periodic boundaries: Kohn–Nirenberg quantization,
          Op(p)(x, D) u(x) ≈ ∫ exp(i·x·ξ) p(x, ξ) 𝓕(u)(ξ) dξ, evaluated with FFTs (faster).
        - Spatially varying symbol, Dirichlet boundaries: the same integral evaluated with a
          non-periodic, convolution-like quadrature (slower).

        Parameters
        ----------
        u : ndarray
            Field on the spatial grid to which the operators are applied.

        Returns
        -------
        ndarray
            Complex array ∑ₖ cₖ · Op(pₖ)u.

        Raises
        ------
        ValueError
            If no pseudo-differential operators are defined, or if the dimension is not 1 or 2.
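
        Examples
        --------
        A minimal sketch of the first dispatch case, an x-independent symbol applied as a
        Fourier multiplier, together with the coefficient-weighted sum over operators.
        `fourier_multiplier` is a hypothetical helper, not the `apply()` method of
        PseudoDifferentialOperator. With p(ξ) = ξ², Op(p) acts as −d²/dx².

```python
import numpy as np
from scipy.fft import fft, ifft, fftfreq


def fourier_multiplier(u, p, L):
    # Hypothetical helper: Op(p)(D) u = IFFT[p(xi) * FFT(u)] on a periodic grid of length L.
    n = u.shape[0]
    xi = 2 * np.pi * fftfreq(n, d=L / n)
    return ifft(p(xi) * fft(u))


L, n = 2 * np.pi, 64
x = np.linspace(-L / 2, L / 2, n, endpoint=False)
u = np.sin(3 * x)

# (coefficient, symbol) pairs, combined as sum_k c_k * Op(p_k) u.
ops = [(1.0, lambda xi: xi**2), (0.5, lambda xi: np.ones_like(xi))]
result = sum(c * fourier_multiplier(u, p, L) for c, p in ops)
```

        For u = sin(3x) the first operator returns 9·sin(3x) and the second 0.5·sin(3x),
        so the sum is 9.5·sin(3x) with a negligible imaginary part.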
        """
        if not hasattr(self, 'psi_ops') or not self.psi_ops:
            raise ValueError("No pseudo-differential operators defined")

        result = np.zeros_like(u, dtype=np.complex128)

        for coeff, psi_op in self.psi_ops:
            coeff = np.complex128(coeff)
            if self.dim == 1:
                contribution = psi_op.apply(
                    u=u,
                    x_grid=self.x_grid,
                    kx=self.kx,
                    boundary_condition=self.boundary_condition,
                    dealiasing_mask=self.dealiasing_mask
                )
            elif self.dim == 2:
                contribution = psi_op.apply(
                    u=u,
                    x_grid=self.x_grid,
                    kx=self.kx,
                    y_grid=self.y_grid,
                    ky=self.ky,
                    boundary_condition=self.boundary_condition,
                    dealiasing_mask=self.dealiasing_mask
                )
            else:
                raise ValueError("Only 1D and 2D supported")

            result += coeff * contribution

        return result

    def step_order1_with_psi(self, source_contribution):
        """
        Perform one time step of a first-order evolution using a pseudo-differential operator.

        The equation is ∂ₜu = −Op(p)u + N(u) + F, where p is the combined symbol
        (`self.combined_symbol`). The update uses an exponential integrator or the explicit
        Euler scheme, depending on the boundary conditions and on whether the symbol
        depends on the spatial variables. It handles:

        - linear dynamics through the (possibly nonlocal) pseudo-differential operator,
        - nonlinear terms computed with spectral differentiation,
        - external source contributions.

        The update follows **three computational paths**:

        1. **Periodic boundaries, symbol independent of x**
           The symbol is diagonal in Fourier space; the linear part is propagated exactly
           (with dealiasing) and the nonlinear and source terms are added explicitly:
               uₙ₊₁ = 𝓕⁻¹[exp(−p·Δt) ⋅ 𝓕(uₙ)] + Δt ⋅ N(uₙ) + F

        2. **Non-periodic boundaries, symbol independent of x**
           First-order exponential time differencing (ETD1) in Fourier space:
               ûₙ₊₁ = exp(−p·Δt) ⋅ ûₙ + Δt ⋅ φ₁(−p·Δt) ⋅ (N̂(uₙ) + F̂)

        3. **Spatially varying symbol**
           No Fourier diagonalization is available, so explicit Euler is used, followed
           by a smooth spectral low-pass filter:
               uₙ₊₁ = uₙ + Δt ⋅ (−Op(p)uₙ + N(uₙ) + F)

        where:
            p     = combined pseudo-differential symbol
            N(uₙ) = nonlinear contribution at the current time step
            F     = external source term (`source_contribution`)
            Δt    = time step size
            φ₁(z) = (eᶻ − 1)/z, with φ₁(0) = 1

        Boundary conditions are applied after each update.

        Parameters
        ----------
        source_contribution : np.ndarray or scalar
            External source term at the current time step, with the same shape as
            `self.u_prev`. A scalar is treated as a zero source.

        Returns
        -------
        np.ndarray
            Updated solution array after one time step.
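
        Examples
        --------
        A minimal per-mode sketch of the ETD1 update of path 2 for ∂ₜû = −p·û + F with
        constant F, for which ETD1 reproduces the exact solution. The helper `phi1` and
        the test values are assumptions made for illustration; only real symbol values
        are used.

```python
import numpy as np


def phi1(z):
    # phi_1(z) = (exp(z) - 1) / z, with the removable singularity phi_1(0) = 1.
    z = np.asarray(z, dtype=float)
    out = np.ones_like(z)
    nz = z != 0
    out[nz] = np.expm1(z[nz]) / z[nz]
    return out


p = np.array([0.0, 0.5, 4.0])  # symbol values on three Fourier modes (illustrative)
F = np.ones(3)                 # constant forcing per mode
u0 = np.full(3, 2.0)
dt, nsteps = 0.1, 10

u = u0.copy()
for _ in range(nsteps):
    # ETD1: u <- exp(-p dt) u + dt * phi1(-p dt) * F
    u = np.exp(-p * dt) * u + dt * phi1(-p * dt) * F

# Exact solution at T: exp(-pT) u0 + (1 - exp(-pT)) / p * F, and u0 + T * F where p = 0.
T = dt * nsteps
p_safe = np.where(p == 0, 1.0, p)
exact = np.where(p == 0, u0 + T * F, np.exp(-p * T) * u0 + (1 - np.exp(-p * T)) / p_safe * F)
```

        Since φ₁ is evaluated at −p·Δt, the factor Δt·φ₁(−p·Δt) equals
        (1 − exp(−p·Δt))/p and tends to Δt as p → 0.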
1317        """
1318        # Handling null source
1319        if np.isscalar(source_contribution):
1320            source = np.zeros_like(self.u_prev)
1321        else:
1322            source = source_contribution
1323
1324        def spectral_filter(u, cutoff=0.8):
1325            if u.ndim == 1:
1326                u_hat = self.fft(u)
1327                N = len(u)
1328                k = fftfreq(N)
1329                mask = np.exp(-(k / cutoff)**8)
1330                return self.ifft(u_hat * mask).real
1331            elif u.ndim == 2:
1332                u_hat = self.fft(u)
1333                Ny, Nx = u.shape
1334                ky = fftfreq(Ny)[:, None]
1335                kx = fftfreq(Nx)[None, :]
1336                k_squared = kx**2 + ky**2
1337                mask = np.exp(-(np.sqrt(k_squared) / cutoff)**8)
1338                return self.ifft(u_hat * mask).real
1339            else:
1340                raise ValueError("Only 1D and 2D arrays are supported.")
1341
        # Recalculate symbol if necessary
        if self.is_spatial:
            self.prepare_symbol_tables()  # Recalculates self.combined_symbol

        # Case with FFT (symbol diagonalizable in Fourier space)
        if self.boundary_condition == 'periodic' and not self.is_spatial:
            u_hat = self.fft(self.u_prev)
            u_hat *= np.exp(-self.dt * self.combined_symbol)
            u_hat *= self.dealiasing_mask
            u_symb = self.ifft(u_hat)
            u_nl = self.apply_nonlinear(self.u_prev)
            u_new = u_symb + u_nl + source
        else:
            if not self.is_spatial:
                # General case with ETD1
                u_nl = self.apply_nonlinear(self.u_prev)

                # exp(-dt * L) and phi1(-dt * L) = (exp(-dt * L) - 1) / (-dt * L) = (1 - exp(-dt * L)) / (dt * L)
                L_vals = self.combined_symbol  # Uses the updated symbol
                exp_L = np.exp(-self.dt * L_vals)
                with np.errstate(divide='ignore', invalid='ignore'):
                    phi1_L = (1.0 - exp_L) / (self.dt * L_vals)
                phi1_L[np.isnan(phi1_L)] = 1.0  # 0/0 where L = 0: use the limit phi1(0) = 1

                # Fourier transform
                u_hat = self.fft(self.u_prev)
                u_nl_hat = self.fft(u_nl)
                source_hat = self.fft(source)

                # Assembling the solution in Fourier space
                u_hat_new = exp_L * u_hat + self.dt * phi1_L * (u_nl_hat + source_hat)
                u_new = self.ifft(u_hat_new)
            else:
                # Symbol depends on the spatial variables: explicit Euler step, then spectral filtering
                Lu_prev = -self.apply_psiOp(self.u_prev)
                u_nl = self.apply_nonlinear(self.u_prev)
                u_new = self.u_prev + self.dt * (Lu_prev + u_nl + source)
                u_new = spectral_filter(u_new, cutoff=self.dealiasing_ratio)
        # Applying boundary conditions
        self.apply_boundary(u_new)
        return u_new

    def step_order2_with_psi(self, source_contribution):
        """
        Perform one time step of a second-order time evolution using a pseudo-differential operator.

        This method updates the solution field using a second-order accurate scheme suitable for wave-like equations.
        The update includes contributions from:
        - Linear dynamics via a pseudo-differential operator (e.g., dispersion or stiffness)
        - Nonlinear terms computed via spectral differentiation
        - External source contributions

        Discretization follows a leapfrog-style (central) finite difference in time:

            uₙ₊₁ = 2uₙ − uₙ₋₁ + Δt² ⋅ (L(uₙ) + N(uₙ) + F)

        where:
            L(uₙ) = −P(uₙ), with P the pseudo-differential operator applied by apply_psiOp
            N(uₙ) = nonlinear contribution at current time step
            F     = external source term at current time step
            Δt    = time step size

        Boundary conditions are applied after each update to ensure consistency.

        Parameters:
            source_contribution (np.ndarray or scalar): External source term at the current time step.
                Must match the spatial shape of self.u_prev (solve() passes the scalar 0 when no source term is defined).

        Returns:
            np.ndarray: Updated solution array after one time step. The time history is also shifted:
            self.u_prev2 takes the old self.u_prev, and self.u_prev and self.u are set to the new solution.
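
        Examples:
        A scalar sketch of the same three-level recurrence for the harmonic oscillator u'' = −ω²u
        (so L(u) = −ω²u and N = F = 0). The helper leapfrog and its Taylor start-up step for u₁ are
        illustrative assumptions, not psipy API.

```python
import numpy as np

def leapfrog(accel, u0, v0, dt, n_steps):
    # Central-difference recurrence u_{n+1} = 2 u_n - u_{n-1} + dt**2 * accel(u_n)
    u_prev = u0
    # Second-order Taylor start-up step gives u_1 from (u_0, v_0)
    u = u0 + dt * v0 + 0.5 * dt**2 * accel(u0)
    for _ in range(n_steps - 1):
        u, u_prev = 2 * u - u_prev + dt**2 * accel(u), u
    return u

omega, dt, T = 2.0, 1e-3, 1.0
n_steps = int(round(T / dt))
u_T = leapfrog(lambda q: -omega**2 * q, u0=1.0, v0=0.0, dt=dt, n_steps=n_steps)
error = abs(u_T - np.cos(omega * T))

# Same run with twice the step size, to observe the second-order convergence
u_T_coarse = leapfrog(lambda q: -omega**2 * q, 1.0, 0.0, 2 * dt, n_steps // 2)
error_coarse = abs(u_T_coarse - np.cos(omega * T))
```

        Doubling Δt multiplies the error by about 4, as expected for a second-order scheme; the
        method stores u_prev2 precisely so that this two-level history is available at every step.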
        """
        Lu_prev = -self.apply_psiOp(self.u_prev)
        rhs_nl = self.apply_nonlinear(self.u_prev, is_v=False)
        u_new = 2 * self.u_prev - self.u_prev2 + self.dt ** 2 * (Lu_prev + rhs_nl + source_contribution)
        self.apply_boundary(u_new)
        self.u_prev2 = self.u_prev
        self.u_prev = u_new
        self.u = u_new
        return u_new

    def solve(self):
        """
        Solve the partial differential equation numerically using spectral methods.

        This method evolves the solution in time using a combination of:
        - Fourier-based linear evolution (with dealiasing)
        - Nonlinear term handling via pseudo-spectral evaluation
        - Support for pseudo-differential operators (psiOp)
        - Source terms and boundary conditions

        The solver supports:
        - 1D and 2D spatial domains
        - First and second-order time evolution
        - Periodic and Dirichlet boundary conditions
        - Time-stepping schemes: default, ETD-RK4

        Returns:
            list[np.ndarray]: A list of solution arrays at each saved time frame.

        Side Effects:
            - Updates self.frames: stores solution snapshots
            - Updates self.energy_history: records the total energy at every step
              (second-order systems without psiOp only)

        Algorithm Overview:
            For each time step:
                1. Evaluate source contributions (if any) at t = step · dt
                2. Apply time evolution:
                    - Order 1:
                        - With psiOp: uses step_order1_with_psi
                        - With ETD-RK4: uses step_ETD_RK4 (source terms are not passed to this step)
                        - Default: exact Fourier propagation of the linear part (self.exp_L),
                          followed by an explicit nonlinear and source update
                    - Order 2:
                        - With psiOp: uses step_order2_with_psi (leapfrog-style update)
                        - With ETD-RK4: uses step_ETD_RK4_order2
                        - Default: exact cos/sin propagation of (u, ∂ₜu) in Fourier space,
                          followed by an explicit nonlinear and source correction
                3. Enforce boundary conditions
                4. Save a solution snapshot every max(1, Nt // n_frames) steps
                5. Record energy (for second-order systems without psiOp)
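
        Examples:
        A minimal stand-alone sketch of the default first-order branch for the linear heat equation
        uₜ = ν uₓₓ on a periodic grid, with no nonlinear or source terms. The names exp_L,
        save_interval and frames mirror the solver's attributes, but the code does not call psipy;
        grid size, ν and the time step are arbitrary illustrative choices.

```python
import numpy as np

Nx, Lx, nu = 64, 2 * np.pi, 0.1
x = np.linspace(0.0, Lx, Nx, endpoint=False)
k = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)       # angular wavenumbers (FFT ordering)
dt, Nt, n_frames = 0.01, 200, 10
exp_L = np.exp(-nu * k**2 * dt)                       # exact linear propagator over one step

u = np.sin(3 * x)
frames = []
save_interval = max(1, Nt // n_frames)
for step in range(Nt):
    u = np.fft.ifft(exp_L * np.fft.fft(u)).real       # linear part only
    if step % save_interval == 0:                     # snapshot after the update, as in solve()
        frames.append(u.copy())

t_final = Nt * dt
exact = np.exp(-9 * nu * t_final) * np.sin(3 * x)    # mode k = 3 decays at rate nu * k**2
```

        Because the linear propagator is applied exactly in Fourier space, the single mode decays at
        the exact rate for any Δt; time-step errors only enter through the nonlinear and source terms.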
        """
        print('\n*******************')
        print('* Solving the PDE *')
        print('*******************\n')
        save_interval = max(1, self.Nt // self.n_frames)
        self.energy_history = []

        # Lambdify the source terms once, outside the time loop
        source_funcs = []
        if hasattr(self, 'source_terms') and self.source_terms:
            args = (self.t, self.x) if self.dim == 1 else (self.t, self.x, self.y)
            source_funcs = [(term, lambdify(args, term, 'numpy')) for term in self.source_terms]
        grid = (self.X,) if self.dim == 1 else (self.X, self.Y)

        for step in range(self.Nt):
            if source_funcs:
                source_contribution = np.zeros_like(self.X, dtype=np.float64)
                for term, source_func in source_funcs:
                    try:
                        source_contribution += source_func(step * self.dt, *grid)
                    except Exception as e:
                        print(f'Error evaluating source term {term}: {e}')
            else:
                source_contribution = 0

            if self.temporal_order == 1:
                if self.has_psi:
                    u_new = self.step_order1_with_psi(source_contribution)
                elif hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
                    u_new = self.step_ETD_RK4(self.u_prev)
                else:
                    u_hat = self.fft(self.u_prev)
                    u_hat *= self.exp_L
                    u_hat *= self.dealiasing_mask
                    u_lin = self.ifft(u_hat)
                    u_nl = self.apply_nonlinear(u_lin)
                    u_new = u_lin + u_nl + source_contribution
                self.apply_boundary(u_new)
                self.u_prev = u_new

            elif self.temporal_order == 2:
                if self.has_psi:
                    u_new = self.step_order2_with_psi(source_contribution)
                else:
                    if hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
                        u_new, v_new = self.step_ETD_RK4_order2(self.u_prev, self.v_prev)
                    else:
                        u_hat = self.fft(self.u_prev)
                        v_hat = self.fft(self.v_prev)
                        u_new_hat = self.cos_omega_dt * u_hat + self.sin_omega_dt * self.inv_omega * v_hat
                        v_new_hat = -self.omega_val * self.sin_omega_dt * u_hat + self.cos_omega_dt * v_hat
                        u_new = self.ifft(u_new_hat)
                        v_new = self.ifft(v_new_hat)
                        u_nl = self.apply_nonlinear(self.u_prev, is_v=False)
                        v_nl = self.apply_nonlinear(self.v_prev, is_v=True)
                        u_new += (u_nl + source_contribution) * self.dt ** 2 / 2
                        v_new += (u_nl + source_contribution) * self.dt
                    self.apply_boundary(u_new)
                    self.apply_boundary(v_new)
                    self.u_prev = u_new
                    self.v_prev = v_new

            if step % save_interval == 0:
                self.frames.append(self.u_prev.copy())

            if self.temporal_order == 2 and (not self.has_psi):
                E = self.compute_energy()
                self.energy_history.append(E)

        return self.frames

    def solve_stationary_psiOp(self, order=3):
        """
        Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.

        This method computes the solution to a stationary (time-independent) pseudo-differential equation
        where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R
        such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication
        (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).

        The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order.
        Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.

        Parameters
        ----------
        order : int, default=3
            Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator.

        Returns
        -------
        ndarray
            The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.

        Raises
        ------
        ValueError
            If no pseudo-differential operator (psiOp) is defined.
            If linear or nonlinear terms other than psiOp are present.
            If the boundary condition is neither 'periodic' nor 'dirichlet'.
            If the symbol is not elliptic on the grid.
            If an initial condition is set (stationary problems take none).
            If no source term is provided for the right-hand side.

        Notes
        -----
        - The method assumes the problem is fully stationary: time derivatives must be absent.
        - Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
        - Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
        - With periodic boundary conditions, an inverse symbol that does not depend on the spatial
          variables is applied directly as a Fourier multiplier; otherwise Kohn–Nirenberg quantization is used.
        - The solution is also stored in self.u.

        See Also
        --------
        right_inverse_asymptotic   : Constructs the asymptotic inverse of the pseudo-differential operator.
        kohn_nirenberg_fft         : Kohn–Nirenberg quantization on periodic grids (psiop module).
        kohn_nirenberg_nonperiodic : Kohn–Nirenberg quantization used for Dirichlet problems (psiop module).
        is_elliptic_numerically    : Verifies numerical ellipticity of the symbol.
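
        Examples
        --------
        For a symbol that does not depend on x, the correction terms of the asymptotic expansion vanish
        and the right inverse reduces to the Fourier multiplier 1/p(ξ). A minimal NumPy sketch of that
        periodic fast path, assuming the model operator P = 1 − ∂ₓ² (symbol p(ξ) = 1 + ξ²) on [0, 2π)
        and taking f directly as the right-hand side (the solver's own sign convention for source
        terms is not reproduced here):

```python
import numpy as np

N, L = 128, 2 * np.pi
x = np.linspace(0.0, L, N, endpoint=False)
xi = 2 * np.pi * np.fft.fftfreq(N, d=L / N)   # angular frequencies on the periodic grid

f = np.cos(x) + 0.5 * np.sin(4 * x)
p = 1.0 + xi**2                                 # elliptic: p(xi) >= 1 everywhere
u = np.fft.ifft(np.fft.fft(f) / p).real         # apply the inverse symbol 1/p(xi)

# Each Fourier mode of f is divided by 1 + xi**2 (2 for xi = 1, 17 for xi = 4)
u_exact = np.cos(x) / 2 + 0.5 * np.sin(4 * x) / 17
residual = np.fft.ifft(p * np.fft.fft(u)).real  # P[u], which should reproduce f
```

        Ellipticity is what makes this division safe: p never vanishes on the grid, so 1/p stays bounded.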
        """

        print("\n******************************")
        print("* Solving the stationary PDE *")
        print("******************************\n")
        print("boundary condition: ", self.boundary_condition)

        if not self.has_psi:
            raise ValueError("Only supports problems with psiOp.")

        if self.linear_terms or self.nonlinear_terms:
            raise ValueError("Stationary psiOp problems must be linear and purely pseudo-differential.")

        if self.boundary_condition not in ('periodic', 'dirichlet'):
            raise ValueError(
                "For stationary PDEs, boundary conditions must be explicitly defined. "
                "Supported types are 'periodic' and 'dirichlet'."
            )

        if self.dim == 1:
            x = self.x
            xi = symbols('xi', real=True)
            spatial_vars = (x,)
            freq_vars = (xi,)
            X, KX = self.X, self.KX
        elif self.dim == 2:
            x, y = self.x, self.y
            xi, eta = symbols('xi eta', real=True)
            spatial_vars = (x, y)
            freq_vars = (xi, eta)
            X, Y, KX, KY = self.X, self.Y, self.KX, self.KY
        else:
            raise ValueError("Unsupported spatial dimension.")

        total_symbol = sum(coeff * psi.expr for coeff, psi in self.psi_ops)
        psi_total = PseudoDifferentialOperator(total_symbol, spatial_vars, mode='symbol')

        # Check ellipticity
        if self.dim == 1:
            is_elliptic = psi_total.is_elliptic_numerically(X, KX)
        else:
            is_elliptic = psi_total.is_elliptic_numerically((X[:, 0], Y[0, :]), (KX[:, 0], KY[0, :]))
        if not is_elliptic:
            raise ValueError("❌ The pseudo-differential symbol is not numerically elliptic on the grid.")
        print("✅ Elliptic pseudo-differential symbol: inversion allowed.")

        R_symbol = psi_total.right_inverse_asymptotic(order=order)
        print('Right inverse asymptotic symbol:')
        pprint(R_symbol, num_columns=NUM_COLS)

        # Lambdify with all spatial and frequency variables, even unused ones, so that R_func
        # always has the signature (x, xi) in 1D or (x, y, xi, eta) in 2D
        if self.dim == 1:
            R_func = lambdify((x, xi), R_symbol, modules='numpy')
        else:
            R_func = lambdify((x, y, xi, eta), R_symbol, modules='numpy')

        # Prepare the right-hand side from the (negated) source terms
        if self.source_terms:
            f_expr = sum(self.source_terms)
            f_func = lambdify(spatial_vars, -f_expr, modules='numpy')
            # Multiplying by ones broadcasts a constant source to the full grid
            if self.dim == 1:
                rhs = f_func(self.x_grid) * np.ones_like(self.x_grid)
            else:
                rhs = f_func(self.X, self.Y) * np.ones_like(self.X)
        elif self.initial_condition:
            raise ValueError('Initial condition should be None for stationary equation.')
        else:
            raise ValueError('No source term provided to construct the right-hand side.')

        f_hat = self.fft(rhs)

        # Application of the inverse operator
        if self.boundary_condition == 'periodic':
            if self.dim == 1:
                if not R_symbol.has(x):
                    print('⚡ Optimization: symbol independent of x – direct product in Fourier.')
                    # x does not appear in R_symbol, so any value (here 0.0) can be passed
                    R_vals = R_func(0.0, self.KX)
                    u_hat = R_vals * f_hat
                    u = self.ifft(u_hat)
                else:
                    print('⚙️ 1D Kohn–Nirenberg quantization')
                    from psiop import kohn_nirenberg_fft
                    u = kohn_nirenberg_fft(
                        u_vals=rhs,
                        symbol_func=R_func,  # signature (x, xi)
                        x_grid=self.x_grid,
                        kx=self.kx,
                        fft_func=self.fft,
                        ifft_func=self.ifft,
                        dim=1
                    )

            elif self.dim == 2:
                if not R_symbol.has(x) and not R_symbol.has(y):
                    print('⚡ Optimization: symbol independent of x and y – direct product in 2D Fourier.')
                    # x and y do not appear in R_symbol, so any values (here 0.0) can be passed
                    R_vals = R_func(0.0, 0.0, self.KX, self.KY)
                    u_hat = R_vals * f_hat
                    u = self.ifft(u_hat)
                else:
                    print('⚙️ 2D Kohn–Nirenberg quantization')
                    from psiop import kohn_nirenberg_fft
                    u = kohn_nirenberg_fft(
                        u_vals=rhs,
                        symbol_func=R_func,  # signature (x, y, xi, eta)
                        x_grid=self.x_grid,
                        kx=self.kx,
                        fft_func=self.fft,
                        ifft_func=self.ifft,
                        dim=2,
                        y_grid=self.y_grid,
                        ky=self.ky
                    )
            self.u = u
            return u

        elif self.boundary_condition == 'dirichlet':
            from psiop import kohn_nirenberg_nonperiodic

            if self.dim == 1:
                u = kohn_nirenberg_nonperiodic(
                    u_vals=rhs,
                    x_grid=self.x_grid,
                    xi_grid=self.kx,
                    symbol_func=R_func  # signature (x, xi)
                )
            elif self.dim == 2:
                u = kohn_nirenberg_nonperiodic(
                    u_vals=rhs,
                    x_grid=(self.x_grid, self.y_grid),
                    xi_grid=(self.kx, self.ky),
                    symbol_func=R_func  # signature (x, y, xi, eta)
                )
            self.u = u
            return u

        else:
            raise ValueError(f"Invalid boundary condition '{self.boundary_condition}'. Supported types are 'periodic' and 'dirichlet'.")

    def step_ETD_RK4(self, u):
        """
        Perform one fourth-order Exponential Time Differencing Runge–Kutta (ETD-RK4) time step
        for first-order in time PDEs of the form:

            ∂ₜu = L u + N(u)

        where L is a linear operator (possibly nonlocal or pseudo-differential) that is diagonal in
        Fourier space, and N is a nonlinear term treated via pseudo-spectral methods. The linear part
        is integrated exactly; the nonlinear part is approximated to fourth order in time.

        ETD-RK4 uses four stages to approximate the integral in the variation-of-constants formula:

            uⁿ⁺¹ = e^(L Δt) uⁿ + Δt ∫₀¹ e^(L Δt (1 − τ)) N(u(tₙ + τ Δt)) dτ

        Parameters:
            u (np.ndarray): Current solution in real space (physical grid values).

        Returns:
            np.ndarray: Updated solution in real space after one ETD-RK4 time step.

        Notes:
        - The values of L are obtained from self.L(KX) in 1D or self.L(KX, KY) in 2D.
        - Nonlinear terms are evaluated in physical space and transformed via FFT.
        - The stage weights are built from the entire functions φₖ(z) = Σ_{j≥0} zʲ/(j + k)!, e.g.

              φ₁(z) = (eᶻ − 1)/z         if z ≠ 0
                     = 1                  if z = 0

              φ₂(z) = (eᶻ − 1 − z)/z²    if z ≠ 0
                     = ½                  if z = 0

        - This implementation assumes periodic boundary conditions and uses spectral differentiation via FFT.
        - See Cox & Matthews (2002) for the ETD-RK4 scheme and Hochbruck & Ostermann (2010) for
          theoretical background on exponential integrators.

        See Also:
            step_ETD_RK4_order2 : For second-order in time equations.
            apply_psiOp         : For applying pseudo-differential operators.
            apply_nonlinear     : For handling nonlinear terms in the PDE.
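
        Examples:
        A scalar sketch of the Cox–Matthews ETD-RK4 step, written independently of the solver: phi,
        etdrk4_step and integrate are hypothetical helpers, and the linear part is a scalar (or a
        diagonal array standing in for the Fourier symbol). The φ-functions switch to a truncated
        Taylor series for |z| < 0.1 to avoid cancellation near z = 0.

```python
import numpy as np
from math import factorial

def phi(k, z):
    # phi_k(z) = sum_{j>=0} z**j / (j + k)!: 8-term Taylor series for |z| < 0.1, closed form elsewhere
    z = np.asarray(z, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        taylor = sum(z**j / factorial(j + k) for j in range(8))
        closed = (np.exp(z) - sum(z**j / factorial(j) for j in range(k))) / z**k
    return np.where(np.abs(z) < 0.1, taylor, closed)

def etdrk4_step(u, lin, nonlin, dt):
    # One Cox-Matthews step for du/dt = lin * u + nonlin(u); lin is a scalar or a diagonal array
    z = dt * lin
    E, E2 = np.exp(z), np.exp(z / 2)
    Q = 0.5 * dt * phi(1, z / 2)               # equals (exp(z/2) - 1) / lin, finite at lin = 0
    p1, p2, p3 = phi(1, z), phi(2, z), phi(3, z)
    Nu = nonlin(u)
    a = E2 * u + Q * Nu
    Na = nonlin(a)
    b = E2 * u + Q * Na
    Nb = nonlin(b)
    c = E2 * a + Q * (2 * Nb - Nu)
    Nc = nonlin(c)
    return E * u + dt * ((p1 - 3 * p2 + 4 * p3) * Nu + 2 * (p2 - 2 * p3) * (Na + Nb) + (4 * p3 - p2) * Nc)

def integrate(u0, lin, nonlin, T, n_steps):
    u, dt = u0, T / n_steps
    for _ in range(n_steps):
        u = etdrk4_step(u, lin, nonlin, dt)
    return u

# Bernoulli test problem: u' = -2u + u**2, u(0) = 0.5, exact u(t) = 1 / (1.5 exp(2t) + 0.5)
exact = 1.0 / (1.5 * np.exp(2.0) + 0.5)
err_coarse = abs(integrate(0.5, -2.0, lambda u: u**2, 1.0, 10) - exact)
err_fine = abs(integrate(0.5, -2.0, lambda u: u**2, 1.0, 20) - exact)

# Linear problem with constant forcing (the weights of N sum to phi_1, so the step is exact)
lin = np.array([0.0, -3.0, -50.0])
u_lin = integrate(np.ones(3), lin, lambda u: 2.0 * np.ones_like(u), 1.0, 10)
```

        On the Bernoulli problem, halving Δt reduces the error by roughly 2⁴, consistent with fourth-order
        accuracy, while the stiff mode (lin = −50) is integrated stably at Δt = 0.1.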
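        """
        from math import factorial

        dt = self.dt
        fft = self.fft
        ifft = self.ifft

        u_hat = fft(u)
        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)
        # Broadcast to the spectral grid in case the symbol is constant (lambdify then returns a scalar)
        z = dt * np.broadcast_to(L_fft, np.shape(u_hat))

        def phi(k, w):
            # phi_k(w) = (exp(w) - sum_{j<k} w**j / j!) / w**k, the entire functions of the ETD scheme.
            # For |w| < 0.1 an 8-term Taylor series sum_j w**j / (j + k)! replaces the closed form,
            # which suffers from cancellation near w = 0.
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                taylor = sum(w**j / factorial(j + k) for j in range(8))
                closed = (np.exp(w) - sum(w**j / factorial(j) for j in range(k))) / w**k
            return np.where(np.abs(w) < 0.1, taylor, closed)

        E = np.exp(z)
        E2 = np.exp(z / 2)
        Q = 0.5 * dt * phi(1, z / 2)          # (exp(z/2) - 1) / L, finite at L = 0
        phi1, phi2, phi3 = phi(1, z), phi(2, z), phi(3, z)
        f1 = phi1 - 3 * phi2 + 4 * phi3       # weight of N(u)
        f2 = phi2 - 2 * phi3                  # weight of N(a) and N(b), each counted twice
        f3 = 4 * phi3 - phi2                  # weight of N(c)

        # Cox–Matthews (2002) stages, assembled in Fourier space
        N_u = fft(self.apply_nonlinear(u))
        a_hat = E2 * u_hat + Q * N_u
        N_a = fft(self.apply_nonlinear(ifft(a_hat)))
        b_hat = E2 * u_hat + Q * N_a
        N_b = fft(self.apply_nonlinear(ifft(b_hat)))
        c_hat = E2 * a_hat + Q * (2 * N_b - N_u)
        N_c = fft(self.apply_nonlinear(ifft(c_hat)))

        u_new_hat = E * u_hat + dt * (f1 * N_u + 2 * f2 * (N_a + N_b) + f3 * N_c)
        return ifft(u_new_hat)
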
    def step_ETD_RK4_order2(self, u, v):
        """
        Perform one fourth-order Runge–Kutta time step for second-order in time PDEs.

        Despite the name, no exponential integrator is used here: the classical RK4 scheme is applied to
        the equivalent first-order system ∂ₜu = v, ∂ₜv = L u + N(u) for equations of the form:

            ∂ₜ²u = L u + N(u)

        where L is a linear operator applied as a Fourier multiplier and N is a nonlinear term computed
        via self.apply_nonlinear. Because L is treated explicitly, the time step is still subject to the
        stability limit of the stiffest (highest-frequency) modes.

        Parameters:
            u (np.ndarray): Current solution array in real space.
            v (np.ndarray): Current time derivative of the solution (∂ₜu) in real space.

        Returns:
            tuple: (u_new, v_new), updated solution and its time derivative after one time step.

        Notes:
            - Assumes periodic boundary conditions and uses FFT-based spectral methods.
            - Handles both 1D and 2D problems.
            - Suitable for wave equations and other second-order evolution equations.
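
        Examples:
        A scalar sketch of the same four-stage update for u'' = a(u), assuming a plain Python callable
        accel in place of L u + N(u); rk4_second_order is a hypothetical helper, not psipy API.
        Expanding the RK4 stages for the system (u' = v, v' = a(u)) gives the position increment
        Δt·v + Δt²(A + B + C)/6 and the velocity increment Δt(A + 2B + 2C + D)/6.

```python
import numpy as np

def rk4_second_order(accel, u, v, dt):
    # Classical RK4 on the system (u' = v, v' = accel(u)), written out for u'' = accel(u)
    A = accel(u)
    B = accel(u + 0.5 * dt * v)
    C = accel(u + 0.5 * dt * v + 0.25 * dt**2 * A)
    D = accel(u + dt * v + 0.5 * dt**2 * B)
    u_new = u + dt * v + dt**2 / 6.0 * (A + B + C)   # D does not enter the position update
    v_new = v + dt / 6.0 * (A + 2 * B + 2 * C + D)
    return u_new, v_new

# Constant acceleration (free fall): the update is exact
u_fall, v_fall = 0.0, 0.0
for _ in range(10):
    u_fall, v_fall = rk4_second_order(lambda q: -9.81, u_fall, v_fall, 0.1)

# Harmonic oscillator u'' = -u, u(0) = 1, u'(0) = 0, compared with cos(t)
def oscillator_error(dt, T=1.0):
    u, v = 1.0, 0.0
    for _ in range(int(round(T / dt))):
        u, v = rk4_second_order(lambda q: -q, u, v, dt)
    return abs(u - np.cos(T))

err_coarse, err_fine = oscillator_error(0.1), oscillator_error(0.05)
```

        The free-fall case is a quick sanity check: after t = 1 it must give u = −g/2 and v = −g exactly.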
        """
        dt = self.dt

        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)
        fft = self.fft
        ifft = self.ifft

        def rhs(u_val):
            # Acceleration u_tt = L u + N(u), with L applied as a Fourier multiplier
            return ifft(L_fft * fft(u_val)) + self.apply_nonlinear(u_val, is_v=False)

        # Classical RK4 on the first-order system (u' = v, v' = rhs(u))
        # Stage A
        A = rhs(u)
        ua = u + 0.5 * dt * v
        va = v + 0.5 * dt * A

        # Stage B
        B = rhs(ua)
        ub = u + 0.5 * dt * va
        vb = v + 0.5 * dt * B

        # Stage C
        C = rhs(ub)
        uc = u + dt * vb

        # Stage D
        D = rhs(uc)

        # Final update: the RK4 position increment is dt*v + dt**2/6 * (A + B + C);
        # D only enters the velocity update
        u_new = u + dt * v + (dt**2 / 6.0) * (A + B + C)
        v_new = v + (dt / 6.0) * (A + 2*B + 2*C + D)

        return u_new, v_new

    def check_cfl_condition(self):
        """
        Check the CFL (Courant–Friedrichs–Lewy) condition based on group velocity
        for second-order time-dependent PDEs.

        This method verifies whether the chosen time step dt satisfies the numerical stability
        condition derived from the maximum wave propagation speed in the system. It supports both
        1D and 2D problems, with or without a symbolic dispersion relation ω(k).

        The CFL condition ensures that information does not propagate further than one grid cell
        per time step. A fixed safety factor of 0.5 is applied:

            1D:  dt ≤ 0.5 · dx / max|v_g|
            2D:  dt ≤ 0.5 / (max|v_g,x| / dx + max|v_g,y| / dy)

        Notes:

        - In 1D, the group velocity v_g(k) = dω/dk is used to compute the maximum wave speed.
        - In 2D, the directional group velocities ∂ω/∂kx (at ky = 0) and ∂ω/∂ky (at kx = 0)
          are evaluated independently.
        - If no dispersion relation is available, max|Im L(k)| is used as a rough proxy
          for the wave speed.

        Raises:
        -------
        NotImplementedError:
            If the spatial dimension is not 1D or 2D.

        Prints:
        -------
        Warning message if the current time step dt exceeds the CFL-stable limit.
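
        Examples:
        ---------
        A stand-alone sketch of the 1D branch, assuming a periodic grid with angular wavenumbers
        k = 2π·fftfreq(Nx, dx) and a vectorized callable ω(k); cfl_dt_max_1d is a hypothetical helper.
        Sorting k matters because FFT ordering jumps from +k_max to −k_max, while np.gradient
        expects a monotonic coordinate axis.

```python
import numpy as np

def cfl_dt_max_1d(omega, Lx, Nx, cfl_factor=0.5):
    # Largest dt with dt <= cfl_factor * dx / max|d omega / dk| on a periodic 1D grid
    dx = Lx / Nx
    k = np.sort(2 * np.pi * np.fft.fftfreq(Nx, d=dx))   # monotonic angular wavenumbers
    v_group = np.gradient(omega(k), k)                    # numerical group velocity
    max_speed = np.max(np.abs(v_group))
    return cfl_factor * dx / max_speed if max_speed > 0 else np.inf

c = 3.0
dt_wave = cfl_dt_max_1d(lambda k: c * np.abs(k), Lx=2 * np.pi, Nx=64)   # non-dispersive waves
dt_beam = cfl_dt_max_1d(lambda k: k**2, Lx=2 * np.pi, Nx=64)             # dispersive: v_g = 2k
dt_static = cfl_dt_max_1d(lambda k: 0.0 * k, Lx=2 * np.pi, Nx=64)        # no propagation
```

        For ω = c|k| the limit is 0.5·dx/c, whereas for ω = k² the group speed grows with the highest
        resolved wavenumber, so the admissible step shrinks much faster as the grid is refined.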
        """
        print("\n*****************")
        print("* CFL condition *")
        print("*****************\n")

        cfl_factor = 0.5  # Safety factor

        if self.dim == 1:
            if self.temporal_order == 2 and hasattr(self, 'omega'):
                # Sort the wavenumbers: FFT ordering jumps from +k_max to -k_max,
                # and np.gradient needs a monotonic coordinate axis
                k_vals = np.sort(self.kx)
                omega_vals = np.real(self.omega(k_vals))
                with np.errstate(divide='ignore', invalid='ignore'):
                    v_group = np.gradient(omega_vals, k_vals)
                max_speed = np.max(np.abs(v_group))
            else:
                max_speed = np.max(np.abs(np.imag(self.L(self.kx))))

            dx = self.Lx / self.Nx
            cfl_limit = cfl_factor * dx / max_speed if max_speed != 0 else np.inf

            if self.dt > cfl_limit:
                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")

        elif self.dim == 2:
            if self.temporal_order == 2 and hasattr(self, 'omega'):
                # Differentiate along kx with the x-wavenumbers and along ky with the y-wavenumbers
                kx_vals = np.sort(self.kx)
                ky_vals = np.sort(self.ky)
                omega_x = np.real(self.omega(kx_vals, 0))
                omega_y = np.real(self.omega(0, ky_vals))
                with np.errstate(divide='ignore', invalid='ignore'):
                    v_group_x = np.gradient(omega_x, kx_vals)
                    v_group_y = np.gradient(omega_y, ky_vals)
                max_speed_x = np.max(np.abs(v_group_x))
                max_speed_y = np.max(np.abs(v_group_y))
            else:
                max_speed_x = np.max(np.abs(np.imag(self.L(self.kx, 0))))
                max_speed_y = np.max(np.abs(np.imag(self.L(0, self.ky))))

            dx = self.Lx / self.Nx
            dy = self.Ly / self.Ny
            cfl_limit = cfl_factor / (max_speed_x / dx + max_speed_y / dy) if (max_speed_x + max_speed_y) != 0 else np.inf

            if self.dt > cfl_limit:
                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")

        else:
            raise NotImplementedError("Only 1D and 2D problems are supported.")

    def check_symbol_conditions(self, k_range=None, verbose=True):
        """
        Check analytic conditions on the linear symbol self.L_symbolic.

        This method evaluates three key properties of the Fourier multiplier symbol
        a(k) = self.L(k), which are relevant to well-posedness, stability, and numerical
        efficiency. The checks apply to both 1D and 2D cases.

        Conditions checked:
        ------------------
        1. **Stability condition**: Re(a(k)) ≤ 0 at every sampled k (tolerance 1e-12)
           Ensures that no Fourier mode grows exponentially in time.

        2. **Dissipation condition**: Re(a(k)) ≤ -δ |k|² for |k| > 2, with δ = 0.01
           Ensures sufficient damping at high frequencies to avoid oscillatory instability.
           Purely dispersive symbols (e.g. transport or wave equations) fail this check, since Re(a(k)) = 0.

        3. **Growth condition**: |a(k)| ≤ C (1 + |k|)⁴ with C = 100
           Ensures that the symbol does not grow too rapidly with frequency,
           which would otherwise cause numerical instability or unphysical amplification.

        Parameters
        ----------
        k_range : tuple or None, optional
            Specifies the range of frequencies to test in the form (k_min, k_max, N).
            If None, defaults are used: [-10, 10] with 500 points in 1D, or [-10, 10]
            with 100 points per axis in 2D.

        verbose : bool, default=True
            If True, prints detailed results of each condition check.

        Returns:
        --------
        None
            Output is printed directly to the console for interpretability.

        Notes:
        ------
        - In 2D, the radial frequency |k| = √(kx² + ky²) is used for comparisons.
        - The thresholds (δ = 0.01, exponent 2, cut-off |k| > 2, C = 100) are fixed.
        - The summary line "Unstable symbol: Re(a(k)) > 0" is printed even when verbose is False.
        - This function is typically called during solver setup or analysis phase.

        See Also:
        ---------
        analyze_wave_propagation : For further symbolic and numerical analysis of dispersion.
        plot_symbol : Visualizes the symbol's behavior over the frequency domain.
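
        Examples:
        ---------
        A stand-alone sketch of the same three tests with the method's fixed thresholds, returning
        booleans instead of printing; symbol_checks is a hypothetical helper and the sample symbols
        are illustrative choices. It makes the behavior of condition 2 concrete: a pure transport
        symbol is stable but is flagged as insufficiently dissipative.

```python
import numpy as np

def symbol_checks(a, k_abs, delta=0.01, k_cut=2.0, m=4, C=100.0, tol=1e-12):
    # Returns (stable, dissipative, moderate_growth) for symbol values a sampled at radii k_abs
    re = np.real(a)
    stable = bool(np.all(re <= tol))
    high = k_abs > k_cut
    dissipative = bool(np.all(re[high] <= -delta * k_abs[high]**2 + 1e-6))
    moderate_growth = bool(np.max(np.abs(a) / (1 + k_abs)**m) <= C)
    return stable, dissipative, moderate_growth

k = np.linspace(-10, 10, 500)
nu, c = 0.1, 1.0
a_adv_diff = -nu * k**2 - 1j * c * k   # u_t + c u_x = nu u_xx
a_transport = -1j * c * k              # u_t + c u_x = 0 (no damping)
a_backward_heat = nu * k**2            # u_t = -nu u_xx (ill-posed)
a_steep = -k**8                        # strongly damped but fast-growing symbol

results = {name: symbol_checks(a, np.abs(k)) for name, a in [
    ("advection-diffusion", a_adv_diff),
    ("transport", a_transport),
    ("backward heat", a_backward_heat),
    ("steep", a_steep),
]}
```

        Each tuple reads (stability, dissipation, growth), in the order the method reports them.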
        """
        print("\n********************")
        print("* Symbol condition *")
        print("********************\n")

        if self.dim == 1:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 500)
            else:
                k_min, k_max, N = k_range
                k_vals = np.linspace(k_min, k_max, N)

            L_vals = self.L(k_vals)
            k_abs = np.abs(k_vals)

        elif self.dim == 2:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 100)
            else:
                k_min, k_max, N = k_range
                k_vals = np.linspace(k_min, k_max, N)

            KX, KY = np.meshgrid(k_vals, k_vals)
            L_vals = self.L(KX, KY)
            k_abs = np.sqrt(KX**2 + KY**2)

        else:
            raise ValueError("Only 1D and 2D dimensions are supported.")

        re_vals = np.real(L_vals)
        abs_vals = np.abs(L_vals)

        # === Condition 1: Stability
        if np.any(re_vals > 1e-12):
            max_pos = np.max(re_vals)
            if verbose:
                print(f"❌ Stability violated: max Re(a(k)) = {max_pos}")
            print("Unstable symbol: Re(a(k)) > 0")
        elif verbose:
            print("✅ Spectral stability satisfied: Re(a(k)) ≤ 0")

        # === Condition 2: Dissipation
        mask = k_abs > 2
        if np.any(mask):
            re_decay = re_vals[mask]
            expected_decay = -0.01 * k_abs[mask]**2
            if np.any(re_decay > expected_decay + 1e-6):
                if verbose:
                    print("⚠️ Insufficient high-frequency dissipation")
            else:
                if verbose:
                    print("✅ Proper high-frequency dissipation")

        # === Condition 3: Growth
        growth_ratio = abs_vals / (1 + k_abs)**4
        if np.max(growth_ratio) > 100:
            if verbose:
                print("⚠️ Symbol grows rapidly: |a(k)| ≳ |k|^4")
        else:
            if verbose:
                print("✅ Reasonable spectral growth")

        if verbose:
            print("✔ Symbol analysis completed.")

    def analyze_wave_propagation(self):
        """
        Perform a detailed analysis of wave propagation characteristics based on the dispersion relation ω(k).

        This method visualizes key wave properties in both 1D and 2D settings:

        - Dispersion relation: ω(k)
        - Phase velocity: v_p(k) = ω(k)/k in 1D (set to 0 at k = 0), Re ω(k)/|k| in 2D
        - Group velocity: v_g(k) = ∇ₖ ω(k), computed by finite differences (np.gradient)
        - Anisotropy in 2D (via the magnitude of the group velocity)

        The symbolic dispersion relation 'omega_symbolic' must be defined beforehand.
        This is typically available only for second-order-in-time equations.

        In 1D:
            Plots ω(k), v_p(k), and v_g(k) for k in [-10, 10] (1000 points).

        In 2D:
            Displays heatmaps of Re ω(kx, ky), v_p(kx, ky), and |v_g(kx, ky)| on a
            200 × 200 grid covering [-10, 10]².

        Returns:
            None. If 'omega_symbolic' is not defined, a message is printed and the
            method returns without plotting; no AttributeError is raised.

        Side Effects:
            Generates and displays matplotlib plots.
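
        Example:
            The finite-difference velocities can be checked against an analytic case.
            The standalone sketch below assumes the Klein-Gordon-type dispersion relation
            ω(k) = √(k² + m²) with m = 1, chosen only for illustration. Its exact group
            velocity is k/ω. The helper `phase_and_group_velocity` is hypothetical. The
            second part of the sketch shows the np.gradient axis convention on a meshgrid
            with default 'xy' indexing: axis 0 varies ky and axis 1 varies kx.

```python
import numpy as np

def phase_and_group_velocity(omega, k):
    # Hypothetical helper: phase velocity omega/k (0 at k = 0) and d(omega)/dk
    with np.errstate(divide="ignore", invalid="ignore"):
        v_phase = np.where(k != 0, omega / k, 0.0)
    # Central differences in the interior, one-sided differences at the two ends
    v_group = np.gradient(omega, k[1] - k[0])
    return v_phase, v_group

m = 1.0
k = np.linspace(-10, 10, 1001)
omega = np.sqrt(k**2 + m**2)                       # illustrative dispersion relation
vp, vg = phase_and_group_velocity(omega, k)
err = np.max(np.abs(vg[1:-1] - k[1:-1] / omega[1:-1]))   # exact v_g = k / omega
print(f"max interior group-velocity error: {err:.1e}")

# 2D axis convention: with meshgrid's default 'xy' indexing, rows (axis 0) vary ky
kk = np.linspace(-10, 10, 201)
KX, KY = np.meshgrid(kk, kk)
W = 2.0 * KX + 0.5 * KY                            # linear, so finite differences are exact
dW_dky, dW_dkx = np.gradient(W, kk[1] - kk[0])     # derivative along axis 0, then axis 1
print(dW_dkx.mean(), dW_dky.mean())
```

            The interior error is of order dk². The two end points use one-sided
            differences and are excluded from the comparison.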
        """
        print("\n*****************************")
        print("* Wave propagation analysis *")
        print("*****************************\n")
        if not hasattr(self, 'omega_symbolic'):
            print("❌ omega_symbolic not defined. Only available for 2nd order in time.")
            return

        if self.dim == 1:
            k = self.k_symbols[0]
            omega_func = lambdify(k, self.omega_symbolic, 'numpy')

            k_vals = np.linspace(-10, 10, 1000)
            omega_vals = omega_func(k_vals)

            # Phase velocity ω/k, set to 0 at k = 0
            with np.errstate(divide='ignore', invalid='ignore'):
                v_phase = np.where(k_vals != 0, omega_vals / k_vals, 0.0)

            # Group velocity dω/dk by finite differences
            dk = k_vals[1] - k_vals[0]
            v_group = np.gradient(omega_vals, dk)

            plt.figure(figsize=(10, 6))
            plt.plot(k_vals, omega_vals, label=r'$\omega(k)$')
            plt.plot(k_vals, v_phase, label=r'$v_p(k)$')
            plt.plot(k_vals, v_group, label=r'$v_g(k)$')
            plt.title("1D Wave Propagation Analysis")
            plt.xlabel("k")
            plt.grid()
            plt.legend()
            plt.tight_layout()
            plt.show()

        elif self.dim == 2:
            kx, ky = self.k_symbols
            omega_func = lambdify((kx, ky), self.omega_symbolic, 'numpy')

            k_vals = np.linspace(-10, 10, 200)
            KX, KY = np.meshgrid(k_vals, k_vals)
            K_mag = np.sqrt(KX**2 + KY**2)
            K_mag[K_mag == 0] = 1e-8  # Avoid division by 0

            omega_vals = omega_func(KX, KY)
            v_phase = np.real(omega_vals) / K_mag

            # With meshgrid's default 'xy' indexing, axis 0 varies ky and axis 1 varies kx
            dk = k_vals[1] - k_vals[0]
            domega_dky = np.gradient(omega_vals, dk, axis=0)
            domega_dkx = np.gradient(omega_vals, dk, axis=1)
            v_group_norm = np.sqrt(np.abs(domega_dkx)**2 + np.abs(domega_dky)**2)

            fig, axs = plt.subplots(1, 3, figsize=(18, 5))
            im0 = axs[0].imshow(np.real(omega_vals), extent=[-10, 10, -10, 10],
                                origin='lower', cmap='viridis')
            axs[0].set_title(r'$\omega(k_x, k_y)$')
            plt.colorbar(im0, ax=axs[0])

            im1 = axs[1].imshow(v_phase, extent=[-10, 10, -10, 10],
                                origin='lower', cmap='plasma')
            axs[1].set_title(r'$v_p(k_x, k_y)$')
            plt.colorbar(im1, ax=axs[1])

            im2 = axs[2].imshow(v_group_norm, extent=[-10, 10, -10, 10],
                                origin='lower', cmap='inferno')
            axs[2].set_title(r'$|v_g(k_x, k_y)|$')
            plt.colorbar(im2, ax=axs[2])

            for ax in axs:
                ax.set_xlabel(r'$k_x$')
                ax.set_ylabel(r'$k_y$')
                ax.set_aspect('equal')

            plt.tight_layout()
            plt.show()

        else:
            print("❌ Only 1D and 2D wave analysis supported.")

    def plot_symbol(self, component="abs", k_range=None, cmap="viridis"):
        """
        Visualize the spectral symbol L(k) or L(kx, ky) in 1D or 2D.

        This method plots the Fourier symbol of the linear operator, either as a
        function of a single wavenumber k (1D) or of two wavenumbers kx and ky (2D).
        The user can choose to display the real part, the imaginary part,
        or the absolute value of the symbol.

        Parameters
        ----------
        component : str {'abs', 're', 'im'}
            Component of the symbol to visualize:

                - 'abs' : absolute value |a(k)|
                - 're'  : real part Re[a(k)]
                - 'im'  : imaginary part Im[a(k)]

        k_range : tuple (kmin, kmax, N), optional
            Wavenumber range for evaluation:

                - kmin: minimum wavenumber
                - kmax: maximum wavenumber
                - N: number of sampling points (per axis in 2D)

            If None, defaults to [-10, 10] with 1000 points in 1D, or 300 points per axis in 2D.
        cmap : str, optional
            Colormap used for 2D surface plots. Default is 'viridis'.

        Raises
        ------
            ValueError: If `component` is not 'abs', 're' or 'im', or if the
            spatial dimension is not 1D or 2D.

        Notes:
            - In 1D, the symbol is plotted as a standard line plot.
            - In 2D, a 3D surface plot is generated with color-mapped height.
            - Symbol evaluation uses self.L(k), which must be defined and callable.
        """
        print("\n*******************")
        print("* Symbol plotting *")
        print("*******************\n")

        # Explicit check (an assert would be skipped under python -O)
        if component not in ("abs", "re", "im"):
            raise ValueError("component must be 'abs', 're' or 'im'")

        if self.dim == 1:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 1000)
            else:
                kmin, kmax, N = k_range
                k_vals = np.linspace(kmin, kmax, N)
            L_vals = self.L(k_vals)

            if component == "re":
                vals = np.real(L_vals)
                label = "Re[a(k)]"
            elif component == "im":
                vals = np.imag(L_vals)
                label = "Im[a(k)]"
            else:
                vals = np.abs(L_vals)
                label = "|a(k)|"

            plt.plot(k_vals, vals)
            plt.xlabel("k")
            plt.ylabel(label)
            plt.title(f"Spectral symbol: {label}")
            plt.grid(True)
            plt.show()

        elif self.dim == 2:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 300)
            else:
                kmin, kmax, N = k_range
                k_vals = np.linspace(kmin, kmax, N)

            KX, KY = np.meshgrid(k_vals, k_vals)
            L_vals = self.L(KX, KY)

            if component == "re":
                Z = np.real(L_vals)
                title = "Re[a(kx, ky)]"
            elif component == "im":
                Z = np.imag(L_vals)
                title = "Im[a(kx, ky)]"
            else:
                Z = np.abs(L_vals)
                title = "|a(kx, ky)|"

            fig = plt.figure(figsize=(8, 6))
            ax = fig.add_subplot(111, projection='3d')

            surf = ax.plot_surface(KX, KY, Z, cmap=cmap, edgecolor='none', antialiased=True)
            fig.colorbar(surf, ax=ax, shrink=0.6)

            ax.set_xlabel("kx")
            ax.set_ylabel("ky")
            ax.set_zlabel(title)
            ax.set_title(f"2D spectral symbol: {title}")
            plt.tight_layout()
            plt.show()

        else:
            raise ValueError("Only 1D and 2D supported.")


    def compute_energy(self):
        """
        Compute the total energy of the solution of a second-order-in-time (wave-type) PDE.
        The energy is defined as:
            E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx
        where L is the linear operator associated with the spatial part of the PDE and
        L¹ᐟ² is applied in Fourier space as multiplication by √|L(k)|, the square root
        of the modulus of its symbol. In 2D the integral is taken over x and y.

        This method supports both 1D and 2D problems and is only meaningful when
        self.temporal_order == 2 (second-order time derivative).

        Returns
        -------
        float or None
            Total energy at the current time step, computed from self.u_prev and
            self.v_prev. Returns None if the temporal order is not 2 or if no velocity
            data (v_prev) is available.

        Notes
        -----
        - The spatial term is computed spectrally: û is multiplied by √|L(k)| and transformed back.
        - The integral is approximated by a rectangle-rule sum with dx = Lx/Nx (and dy = Ly/Ny),
          which assumes periodic boundary conditions.
        - Moduli |·|² are used, so both real and complex-valued solutions are handled.
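
        Examples
        --------
        The spectral energy formula can be checked on a case with a known answer. The
        standalone sketch below makes several assumptions. It uses a 1D periodic grid on
        [0, 2π) and the wave-equation symbol L(k) = -k²; only |L(k)| enters, so the sign
        convention does not matter. It also uses the travelling wave u = sin(x - t), with
        ∂ₜu = -cos(x - t), whose exact energy is π. The helper `wave_energy_1d` is
        hypothetical. It builds its own angular wavenumbers with np.fft.fftfreq instead
        of using the solver's KX, fft and ifft.

```python
import numpy as np

def wave_energy_1d(u, v, L_symbol, Lx):
    # Hypothetical helper: E = 1/2 * sum(|v|^2 + |IFFT(sqrt|L(k)| * FFT(u))|^2) * dx
    N = u.size
    dx = Lx / N
    k = 2 * np.pi * np.fft.fftfreq(N, d=dx)      # angular wavenumbers of the periodic grid
    sqrt_L = np.sqrt(np.abs(L_symbol(k)))        # Fourier multiplier sqrt(|L(k)|)
    Lu = np.fft.ifft(sqrt_L * np.fft.fft(u))
    density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
    return np.sum(density) * dx                  # rectangle rule on the periodic grid

Lx, N = 2 * np.pi, 64
x = np.arange(N) * Lx / N                        # periodic grid without the endpoint
t = 0.0
u = np.sin(x - t)                                # travelling wave solving u_tt = u_xx
v = -np.cos(x - t)                               # its time derivative
E = wave_energy_1d(u, v, lambda k: -k**2, Lx)
print(E, np.pi)
```

        Because √|L(k)| = |k|, the spatial term equals ∫ |∂ₓu|² dx by Parseval's
        identity, and E stays equal to π for every t.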
        """
        if self.temporal_order != 2 or self.v_prev is None:
            return None

        u = self.u_prev
        v = self.v_prev

        # Fourier transform of u
        u_hat = self.fft(u)

        if self.dim == 1:
            # 1D case
            L_vals = self.L(self.KX)
            sqrt_L = np.sqrt(np.abs(L_vals))
            Lu_hat = sqrt_L * u_hat  # Apply sqrt(|L(k)|) in Fourier space
            Lu = self.ifft(Lu_hat)

            dx = self.Lx / self.Nx
            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
            total_energy = np.sum(energy_density) * dx

        elif self.dim == 2:
            # 2D case
            L_vals = self.L(self.KX, self.KY)
            sqrt_L = np.sqrt(np.abs(L_vals))
            Lu_hat = sqrt_L * u_hat
            Lu = self.ifft(Lu_hat)

            dx = self.Lx / self.Nx
            dy = self.Ly / self.Ny
            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
            total_energy = np.sum(energy_density) * dx * dy

        else:
            raise ValueError("Unsupported dimension for u.")

        return total_energy

    def plot_energy(self, log=False):
        """
        Plot the time evolution of the total energy for wave equations.
        Visualizes the energy recorded during the simulation, in both 1D and 2D.
        Requires temporal_order=2 and energy values collected with compute_energy() during solve().

        Parameters:
            log : bool
                If True, uses a logarithmic energy axis, on which exponential decay or growth appears as a straight line.

        Notes:
            - Energy is defined as E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx
            - Only available if energy monitoring was activated in solve()
            - If no energy data is available, a message is printed and nothing is plotted
            - Recorded values are assumed to be equally spaced in time over [0, Lt]

        Displays:
            - Time vs. total energy plot with grid and legend
            - Axis labels, with the dimension (1D/2D) shown in the title
            - Logarithmic or linear scaling, depending on `log`
        """
        if not hasattr(self, 'energy_history') or not self.energy_history:
            print("No energy data recorded. Call compute_energy() within solve().")
            return

        # Time vector for plotting (records assumed equally spaced over [0, Lt])
        t = np.linspace(0, self.Lt, len(self.energy_history))

        # Create the figure
        plt.figure(figsize=(6, 4))
        if log:
            plt.semilogy(t, self.energy_history, label="Energy (log scale)")
        else:
            plt.plot(t, self.energy_history, label="Energy")

        # Axis labels and title
        plt.xlabel("Time")
        plt.ylabel("Total energy")
        plt.title("Energy evolution ({}D)".format(self.dim))

        # Display options
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

    def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
        """
        Display the stationary solution computed by solve_stationary_psiOp.

        This method visualizes the solution of a pseudo-differential equation
        solved in stationary mode. It supports both 1D and 2D spatial domains,
        with options to display different components of the solution (real,
        imaginary, absolute value, or phase).

        Parameters
        ----------
        u : ndarray, optional
            Precomputed solution array. If None, calls solve_stationary_psiOp()
            to compute the solution.
        component : str, optional {'real', 'imag', 'abs', 'angle'}
            Component of the complex-valued solution to display:
            - 'real': Real part
            - 'imag': Imaginary part
            - 'abs' : Absolute value (modulus)
            - 'angle' : Phase (argument)
        cmap : str, optional
            Colormap used for the 2D surface plot (default: 'viridis').

        Raises
        ------
        ValueError
            If an invalid component is specified or if the spatial dimension
            is not supported (only 1D and 2D are implemented).

        Notes
        -----
        - In 1D, the solution is displayed using a standard line plot.
        - In 2D, the solution is visualized as a 3D surface plot.
        """
        def get_component(u):
            if component == 'real':
                return np.real(u)
            elif component == 'imag':
                return np.imag(u)
            elif component == 'abs':
                return np.abs(u)
            elif component == 'angle':
                return np.angle(u)
            else:
                raise ValueError("Invalid component")

        if u is None:
            u = self.solve_stationary_psiOp()

        if self.dim == 1:
            # Plot the solution in 1D
            plt.figure(figsize=(8, 4))
            plt.plot(self.x_grid, get_component(u), label=f'{component} of u')
            plt.xlabel('x')
            plt.ylabel(f'{component} of u')
            plt.title('Stationary solution (1D)')
            plt.grid(True)
            plt.legend()
            plt.tight_layout()
            plt.show()

        elif self.dim == 2:
            fig = plt.figure(figsize=(12, 6))
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel(f'{component.title()} of u')
            plt.title('Stationary solution (2D)')
            data0 = get_component(u)
            # Use the colormap passed by the caller
            ax.plot_surface(self.X, self.Y, data0, cmap=cmap)
            plt.tight_layout()
            plt.show()

        else:
            raise ValueError("Only 1D and 2D display are supported.")

    def animate(self, component='abs', overlay='contour', mode='surface'):
        """
        Create an animated plot of the solution evolution over time.

        This method generates a dynamic visualization of the stored solution frames
        `self.frames`. It supports:
          - 1D line animation,
          - 2D surface animation ('surface'),
          - 2D image animation using imshow ('imshow'), which is faster and
            often clearer for large grids.

        Parameters
        ----------
        component : str, optional, one of {'real', 'imag', 'abs', 'angle'}
            Which component of the complex field to visualize:
              - 'real'  : Re(u)
              - 'imag'  : Im(u)
              - 'abs'   : |u|
              - 'angle' : arg(u)
            Default is 'abs'.

        overlay : str or None, optional, one of {'contour', 'front', None}
            For 2D modes only. If None, no overlay is drawn.
              - 'contour' : draw contour lines (in 'surface' mode they are projected
                            onto a plane slightly above the surface)
              - 'front'   : mark wavefronts as local maxima of the gradient magnitude
            Default is 'contour'.

        mode : str, optional, one of {'surface', 'imshow'}
            2D rendering mode. 'surface' draws a 3D surface plot;
            'imshow' draws a 2D raster (faster, often more readable).
            Default is 'surface'.

        Returns
        -------
        FuncAnimation
            A Matplotlib `FuncAnimation` instance (you can display it in a notebook
            or save it to file).

        Notes
        -----
        - Frames are assumed to be stored every max(1, Nt // n_frames) time steps.
          The animation shows n_frames // 2 equally spaced times in [0, Lt], each
          mapped to the nearest stored frame.
        - For 'angle' the color scale is fixed between -π and π.
        - For other components in 'imshow' mode, the color scale is adapted to each
          frame (this avoids extreme clipping if amplitudes vary).
        - In 'imshow' mode, the previous contour/scatter overlay is removed before the
          next one is drawn, to avoid memory and visual accumulation; in 'surface'
          mode the axes are cleared on every frame.
        - The animation interval is 50 ms per frame.
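
        Examples
        --------
        The 'front' overlay marks grid points where the gradient magnitude equals the
        maximum over a 5 × 5 neighbourhood. The standalone sketch below shows that
        detection. It assumes `maximum_filter` is `scipy.ndimage.maximum_filter` and
        that the data array has shape (len(y_grid), len(x_grid)). The helper name
        `wavefront_points` is hypothetical. Its `rel_threshold` cut is an added
        assumption that this method does not make. Without it, flat regions where the
        gradient is exactly zero are also reported as local maxima.

```python
import numpy as np
from scipy.ndimage import maximum_filter

def wavefront_points(data, x_grid, y_grid, size=5, rel_threshold=0.5):
    # Hypothetical helper; data has shape (len(y_grid), len(x_grid))
    dx = x_grid[1] - x_grid[0]
    dy = y_grid[1] - y_grid[0]
    # np.gradient returns the axis-0 (y) derivative first, then the axis-1 (x) one
    du_dy, du_dx = np.gradient(data, dy, dx)
    grad_norm = np.hypot(du_dx, du_dy)
    # Candidate front points: equal to the maximum of their size x size neighbourhood
    local_max = grad_norm == maximum_filter(grad_norm, size=size)
    # Added assumption: keep only maxima above rel_threshold * global maximum
    gmax = grad_norm.max()
    local_max &= grad_norm > rel_threshold * gmax
    X, Y = np.meshgrid(x_grid, y_grid)
    strength = grad_norm[local_max] / gmax if gmax > 0 else np.zeros(0)
    return X[local_max], Y[local_max], strength

x = np.linspace(-5, 5, 201)
y = np.linspace(-3, 3, 61)
X, Y = np.meshgrid(x, y)
u = np.tanh((X - 1.0) / 0.3)        # planar front located at x = 1
xs, ys, s = wavefront_points(u, x, y)
print(len(xs), xs.min(), xs.max())
```

        For this planar tanh front, every grid row contributes exactly one point at
        x = 1, and all strengths equal 1. Without the threshold, points in the
        saturated tail of tanh, where the computed gradient is exactly zero, would also
        be returned.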
        """
        def get_component(u):
            if component == 'real':
                return np.real(u)
            elif component == 'imag':
                return np.imag(u)
            elif component == 'abs':
                return np.abs(u)
            elif component == 'angle':
                return np.angle(u)
            else:
                raise ValueError("Invalid component: choose 'real','imag','abs' or 'angle'")

        print("\n*********************")
        print("* Solution plotting *")
        print("*********************\n")

        # === Calculate time vector of stored frames ===
        save_interval = max(1, self.Nt // self.n_frames)
        frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)

        # === Target times for animation ===
        target_times = np.linspace(0, self.Lt, self.n_frames // 2)

        # Map target times to nearest frame indices, clipped to the frames actually stored
        last_frame = len(self.frames) - 1
        frame_indices = [min(int(np.argmin(np.abs(frame_times - t))), last_frame)
                         for t in target_times]

        # -------------------------
        # 1D case
        # -------------------------
        if self.dim == 1:
            fig, ax = plt.subplots()
            initial = get_component(self.frames[0])
            line, = ax.plot(self.X, np.real(initial) if np.iscomplexobj(initial) else initial)
            ax.set_ylim(np.min(initial), np.max(initial))
            ax.set_xlabel('x')
            ax.set_ylabel(f'{component} of u')
            ax.set_title('Initial condition')
            plt.tight_layout()

            def update_1d(frame_number):
                frame = frame_indices[frame_number]
                ydata = get_component(self.frames[frame])
                ydata_real = np.real(ydata) if np.iscomplexobj(ydata) else ydata
                line.set_ydata(ydata_real)
                ax.set_ylim(np.min(ydata_real), np.max(ydata_real))
                current_time = target_times[frame_number]
                ax.set_title(f't = {current_time:.2f}')
                return (line,)

            ani = FuncAnimation(fig, update_1d, frames=len(target_times), interval=50)
            return ani

        # -------------------------
        # 2D case
        # -------------------------
        # Validate mode
        if mode not in ('surface', 'imshow'):
            raise ValueError("Invalid mode: choose 'surface' or 'imshow'")

        # Common data
        data0 = get_component(self.frames[0])

        if mode == 'surface':
            # 3D surface; the axes are cleared and redrawn on every frame
            fig = plt.figure(figsize=(14, 8))
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel(f'{component.title()} of u')
            ax.zaxis.labelpad = 0
            ax.set_title('Initial condition')

            surf = ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
            plt.tight_layout()

            def update_surface(frame_number):
                frame = frame_indices[frame_number]
                current_data = get_component(self.frames[frame])
                # Height of the overlay plane: 5% of the data range above the maximum
                z_offset = np.max(current_data) + 0.05 * (np.max(current_data) - np.min(current_data))

                ax.clear()
                surf_obj = ax.plot_surface(self.X, self.Y, current_data,
                                           cmap='viridis',
                                           vmin=(-np.pi if component == 'angle' else None),
                                           vmax=(np.pi if component == 'angle' else None))
                # overlays
                if overlay == 'contour':
                    # project contours onto a plane slightly above the surface (offset=z_offset)
                    try:
                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool', offset=z_offset)
                    except Exception:
                        # fallback: simple contour without offset if not supported
                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')

                elif overlay == 'front':
                    dx = self.x_grid[1] - self.x_grid[0]
                    dy = self.y_grid[1] - self.y_grid[0]
                    # numpy.gradient: axis0 -> y spacing, axis1 -> x spacing
                    du_dy, du_dx = np.gradient(current_data, dy, dx)
                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
                    if np.max(grad_norm) > 0:
                        normalized = grad_norm[local_max] / np.max(grad_norm)
                    else:
                        normalized = np.zeros(np.count_nonzero(local_max))
                    colors = cm.plasma(normalized)
                    ax.scatter(self.X[local_max], self.Y[local_max],
                               z_offset * np.ones_like(self.X[local_max]),
                               color=colors, s=10, alpha=0.8)

                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel(f'{component.title()} of u')
                current_time = target_times[frame_number]
                ax.set_title(f'Solution at t = {current_time:.2f}')
                return (surf_obj,)

            ani = FuncAnimation(fig, update_surface, frames=len(target_times), interval=50)
            return ani

        else:  # mode == 'imshow'
            fig, ax = plt.subplots(figsize=(7, 6))
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_title('Initial condition')

            # extent uses physical coordinates so axes show real x/y values
            extent = [self.x_grid[0], self.x_grid[-1], self.y_grid[0], self.y_grid[-1]]

            if component == 'angle':
                vmin, vmax = -np.pi, np.pi
                cmap = 'twilight'
            else:
                vmin, vmax = np.min(data0), np.max(data0)
                cmap = 'viridis'

            im = ax.imshow(data0, extent=extent, origin='lower', cmap=cmap,
                           vmin=vmin, vmax=vmax, aspect='auto')
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_label(f"{component} of u")
            plt.tight_layout()

            # Overlay artists are stored as attributes of update_im
            # (update_im.contour_art, update_im.scatter_art) so they can be removed on the next frame

            def update_im(frame_number):
                frame = frame_indices[frame_number]
                current_data = get_component(self.frames[frame])

                # update raster
                im.set_data(current_data)
                if component != 'angle':
                    # dynamic per-frame scaling (keeps contrast when amplitude varies)
                    cmin = np.nanmin(current_data)
                    cmax = np.nanmax(current_data)
                    # avoid identical vmin==vmax
                    if cmax > cmin:
                        im.set_clim(cmin, cmax)

                # remove previous contour if it exists
                if overlay == 'contour':
                    if getattr(update_im, 'contour_art', None) is not None:
                        try:
                            # Matplotlib >= 3.8: a ContourSet is a single artist
                            update_im.contour_art.remove()
                        except Exception:
                            # older Matplotlib: remove the per-level collections
                            for coll in getattr(update_im.contour_art, 'collections', []):
                                try:
                                    coll.remove()
                                except Exception:
                                    pass
                        update_im.contour_art = None
                    # draw new contours (use meshgrid coords)
                    try:
                        update_im.contour_art = ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
                    except Exception:
                        # fallback: contour with axis coordinates (x_grid, y_grid)
                        Xc, Yc = np.meshgrid(self.x_grid, self.y_grid)
                        update_im.contour_art = ax.contour(Xc, Yc, current_data, levels=10, cmap='cool')

                # remove previous scatter if it exists
                if overlay == 'front':
                    if getattr(update_im, 'scatter_art', None) is not None:
                        try:
                            update_im.scatter_art.remove()
                        except Exception:
                            pass
                        update_im.scatter_art = None

                    dx = self.x_grid[1] - self.x_grid[0]
                    dy = self.y_grid[1] - self.y_grid[0]
                    du_dy, du_dx = np.gradient(current_data, dy, dx)
                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
                    if np.max(grad_norm) > 0:
                        normalized = grad_norm[local_max] / np.max(grad_norm)
                    else:
                        normalized = np.zeros(np.count_nonzero(local_max))
                    colors = cm.plasma(normalized)
                    update_im.scatter_art = ax.scatter(self.X[local_max], self.Y[local_max],
                                                       c=colors, s=10, alpha=0.8)

                current_time = target_times[frame_number]
                ax.set_title(f'Solution at t = {current_time:.2f}')
                # Return the image plus any overlay artists. With blit=False (the
                # FuncAnimation default) the whole figure is redrawn on every frame.
                artists = [im]
                if overlay == 'contour' and getattr(update_im, 'contour_art', None) is not None:
                    artists.append(update_im.contour_art)
                if overlay == 'front' and getattr(update_im, 'scatter_art', None) is not None:
                    artists.append(update_im.scatter_art)
                return tuple(artists)

            ani = FuncAnimation(fig, update_im, frames=len(target_times), interval=50)
            return ani

    def test(self, u_exact, t_eval=None, norm='relative', threshold=1e-2, component='real'):
        """
        Test the solver against an exact solution.

        This method quantitatively compares the numerical solution with a provided exact solution
        at a specified time, using either a relative or an absolute error norm. It supports both
        stationary and time-dependent problems in 1D and 2D. It can also generate plots of the
        numerical solution, the exact solution, and the pointwise error.

        Parameters
        ----------
        u_exact : callable
            Exact solution function taking spatial coordinates and optionally time as arguments.
        t_eval : float, optional
            Time at which to compare solutions. For non-stationary problems, defaults to the final time Lt.
            Ignored for stationary problems.
        norm : str {'relative', 'absolute'}
            Type of error norm used in the comparison.
        threshold : float
            Acceptable error threshold; an AssertionError is raised if it is exceeded.
        component : str {'real', 'imag', 'abs'}
            Component of the solution to compare and visualize.

        Raises
        ------
        ValueError
            If an unsupported dimension is encountered or the requested evaluation time exceeds the simulation duration.
        AssertionError
            If the computed error exceeds the given threshold.

        Prints
        ------
        - Information about the closest available frame to the requested evaluation time.
        - Computed error value and comparison to threshold.

        Notes
        -----
        - For time-dependent problems, the solution is taken from the stored frame closest to t_eval.
        - Plots are adapted to the spatial dimension: line plots for 1D, image plots for 2D.
        - Real, imaginary, and magnitude components are handled consistently.
2740        """
2741        if self.is_stationary:
2742            print("Testing a stationary solution.")
2743            u_num = self.u
2744    
2745            # Compute exact solution
2746            if self.dim == 1:
2747                u_ex = u_exact(self.X)
2748            elif self.dim == 2:
2749                u_ex = u_exact(self.X, self.Y)
2750            else:
2751                raise ValueError("Unsupported dimension.")
2752            actual_t = None
2753        else:
2754            if t_eval is None:
2755                t_eval = self.Lt
2756    
2757            save_interval = max(1, self.Nt // self.n_frames)
2758            frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
2759            frame_index = np.argmin(np.abs(frame_times - t_eval))
2760            actual_t = frame_times[frame_index]
2761            print(f"Closest available time to t_eval={t_eval}: {actual_t}")
2762    
2763            if frame_index >= len(self.frames):
2764                raise ValueError(f"Time t = {t_eval} exceeds simulation duration.")
2765    
2766            u_num = self.frames[frame_index]
2767    
2768            # Compute exact solution at the actual time
2769            if self.dim == 1:
2770                u_ex = u_exact(self.X, actual_t)
2771            elif self.dim == 2:
2772                u_ex = u_exact(self.X, self.Y, actual_t)
2773            else:
2774                raise ValueError("Unsupported dimension.")
2775    
2776        # Select component
2777        if component == 'real':
2778            diff = np.real(u_num) - np.real(u_ex)
2779            ref = np.real(u_ex)
2780        elif component == 'imag':
2781            diff = np.imag(u_num) - np.imag(u_ex)
2782            ref = np.imag(u_ex)
2783        elif component == 'abs':
2784            diff = np.abs(u_num) - np.abs(u_ex)
2785            ref = np.abs(u_ex)
2786        else:
2787            raise ValueError("Invalid component.")
2788    
2789        # Compute error
2790        if norm == 'relative':
2791            error = np.linalg.norm(diff) / np.linalg.norm(ref)
2792        elif norm == 'absolute':
2793            error = np.linalg.norm(diff)
2794        else:
2795            raise ValueError("Unknown norm type.")
2796    
2797        label_time = f"t = {actual_t}" if actual_t is not None else ""
2798        print(f"Test error {label_time}: {error:.3e}")
2799        assert error < threshold, f"Error too large {label_time}: {error:.3e}"
2800    
2801        # Plot
2802        if self.plot:
2803            if self.dim == 1:
2804                plt.figure(figsize=(12, 6))
2805                plt.subplot(2, 1, 1)
2806                plt.plot(self.X, np.real(u_num), label='Numerical')
2807                plt.plot(self.X, np.real(u_ex), '--', label='Exact')
2808                plt.title(f'Solution {label_time}, error = {error:.2e}')
2809                plt.legend()
2810                plt.grid()
2811    
2812                plt.subplot(2, 1, 2)
2813                plt.plot(self.X, np.abs(diff), color='red')
2814                plt.title('Absolute Error')
2815                plt.grid()
2816                plt.tight_layout()
2817                plt.show()
2818            else:
2819                extent = [-self.Lx/2, self.Lx/2, -self.Ly/2, self.Ly/2]
2820                plt.figure(figsize=(15, 5))
2821                plt.subplot(1, 3, 1)
2822                plt.title("Numerical Solution")
2823                plt.imshow(np.abs(u_num), origin='lower', extent=extent, cmap='viridis')
2824                plt.colorbar()
2825    
2826                plt.subplot(1, 3, 2)
2827                plt.title("Exact Solution")
2828                plt.imshow(np.abs(u_ex), origin='lower', extent=extent, cmap='viridis')
2829                plt.colorbar()
2830    
2831                plt.subplot(1, 3, 3)
2832                plt.title(f"Error (Norm = {error:.2e})")
2833                plt.imshow(np.abs(diff), origin='lower', extent=extent, cmap='inferno')
2834                plt.colorbar()
2835                plt.tight_layout()
2836                plt.show()
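The frame lookup and error norm used by test can be reproduced in isolation. The sketch below is a standalone NumPy restatement of that logic; the helper names closest_frame and relative_l2_error are hypothetical (not part of psipy), and it assumes, as test does, that frames are stored every max(1, Nt // n_frames) time steps starting at t = 0.

```python
import numpy as np


def closest_frame(Nt, n_frames, dt, Lt, t_eval):
    """Return (index, time) of the saved frame closest to t_eval (hypothetical helper)."""
    save_interval = max(1, Nt // n_frames)  # frames are saved every save_interval steps
    frame_times = np.arange(0, Lt + dt, save_interval * dt)
    index = int(np.argmin(np.abs(frame_times - t_eval)))
    return index, float(frame_times[index])


def relative_l2_error(u_num, u_ex):
    """Discrete relative L2 error ||u_num - u_ex|| / ||u_ex||."""
    return np.linalg.norm(u_num - u_ex) / np.linalg.norm(u_ex)


# 1000 steps of dt = 1e-3 with 100 saved frames: one frame every 10 steps
index, t_frame = closest_frame(Nt=1000, n_frames=100, dt=1e-3, Lt=1.0, t_eval=0.5)
print(index, round(t_frame, 12))  # 50 0.5

x = np.linspace(0, 2 * np.pi, 64, endpoint=False)
err = relative_l2_error(1.01 * np.sin(x), np.sin(x))
print(f"{err:.3e}")  # 1.000e-02
```

Because the requested time is snapped to the nearest saved frame, the exact solution must be evaluated at the returned frame time rather than at t_eval itself, which is what test does with actual_t.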

A partial differential equation (PDE) solver based on spectral methods using Fourier transforms.

This solver supports symbolic specification of PDEs via SymPy and numerical solution using high-order spectral techniques. It is designed for both linear and nonlinear time-dependent PDEs, as well as stationary pseudo-differential problems.

Key Features:

  • Symbolic PDE parsing using SymPy expressions
  • 1D and 2D spatial domains with periodic boundary conditions
  • Fourier-based spectral discretization with dealiasing
  • Temporal integration schemes:
    • Default exponential time stepping
    • ETD-RK4 (Exponential Time Differencing Runge-Kutta of 4th order)
  • Nonlinear terms handled through pseudo-spectral evaluation
  • Built-in tools for:
    • Visualization of solutions and error surfaces
    • Symbol analysis of linear and pseudo-differential operators
    • Microlocal analysis (e.g., Hamiltonian flows)
    • CFL condition checking and numerical stability diagnostics
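The ETD-RK4 scheme listed above integrates u_t = L u + N(u) by treating the linear part exactly through exponentials of L and the nonlinear part with a fourth-order Runge–Kutta-type quadrature. The following is a minimal NumPy sketch of the Cox–Matthews ETD-RK4 scheme, with its coefficient functions evaluated by the contour-integral average of Kassam and Trefethen to avoid cancellation when |hL| is small. It assumes L is diagonal (one value per Fourier mode); the helper names etdrk4_coefficients and etdrk4_step are hypothetical, and psipy's own implementation may differ in detail.

```python
import numpy as np


def etdrk4_coefficients(L, h, M=32):
    """Cox–Matthews ETD-RK4 coefficients for a diagonal operator L and step h.

    The phi-type functions are averaged over M points on a unit circle centred
    at each h*L (Kassam–Trefethen contour integral), avoiding cancellation
    error when |h*L| is small.
    """
    L = np.asarray(L, dtype=complex)
    r = np.exp(2j * np.pi * (np.arange(1, M + 1) - 0.5) / M)  # points on the unit circle
    LR = h * L[..., None] + r                                 # shape L.shape + (M,)
    E, E2 = np.exp(h * L), np.exp(h * L / 2)
    Q = h * np.mean((np.exp(LR / 2) - 1) / LR, axis=-1)
    f1 = h * np.mean((-4 - LR + np.exp(LR) * (4 - 3 * LR + LR**2)) / LR**3, axis=-1)
    f2 = h * np.mean((2 + LR + np.exp(LR) * (LR - 2)) / LR**3, axis=-1)
    f3 = h * np.mean((-4 - 3 * LR - LR**2 + np.exp(LR) * (4 - LR)) / LR**3, axis=-1)
    return E, E2, Q, f1, f2, f3


def etdrk4_step(v, N, coeffs):
    """Advance v by one step of u_t = L u + N(u); L is folded into coeffs."""
    E, E2, Q, f1, f2, f3 = coeffs
    Nv = N(v)
    a = E2 * v + Q * Nv
    Na = N(a)
    b = E2 * v + Q * Na
    Nb = N(b)
    c = E2 * a + Q * (2 * Nb - Nv)
    Nc = N(c)
    return E * v + f1 * Nv + 2 * f2 * (Na + Nb) + f3 * Nc


# Scalar test problem u' = -u + u**2, u(0) = 1/2, exact solution u(t) = 1 / (1 + e^t)
h, n_steps = 0.01, 100
coeffs = etdrk4_coefficients(np.array([-1.0]), h)
v = np.array([0.5], dtype=complex)
for _ in range(n_steps):
    v = etdrk4_step(v, lambda w: w**2, coeffs)
err = abs(v[0].real - 1 / (1 + np.e))
print(f"error at t = 1: {err:.1e}")
```

In a Fourier pseudo-spectral setting, L would be the operator symbol evaluated on the wavenumber grid and N would map a spectrum to the (dealiased) spectrum of the nonlinear terms; the coefficients depend only on L and h, so they can be computed once and reused for every step.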

Supported Operators:

  • Linear differential and pseudo-differential operators
  • Nonlinear terms up to second order in derivatives
  • Symbolic operator composition and adjoints
  • Asymptotic inversion of elliptic operators for stationary problems

Example Usage:

>>> from PDESolver import *
>>> u = Function('u')
>>> t, x = symbols('t x')
>>> eq = Eq(diff(u(t, x), t), diff(u(t, x), x, 2) + u(t, x)**2)
>>> def initial(x): return np.sin(x)
>>> solver = PDESolver(eq)
>>> solver.setup(Lx=2*np.pi, Nx=128, Lt=1.0, Nt=1000, initial_condition=initial)
>>> solver.solve()
>>> ani = solver.animate()
>>> HTML(ani.to_jshtml())  # Display animation in Jupyter notebook
PDESolver(equation, time_scheme='default', dealiasing_ratio=0.6666666666666666)
    def __init__(self, equation, time_scheme='default', dealiasing_ratio=2/3):
        """
        Initialize the PDE solver with a given equation.

        This method analyzes the input partial differential equation (PDE),
        identifies the unknown function and its dependencies, determines whether
        the problem is stationary or time-dependent, and prepares symbolic and
        numerical structures for solving in spectral space.

        Supported features:

        - 1D and 2D problems
        - Time-dependent and stationary equations
        - Linear and nonlinear terms
        - Pseudo-differential operators via `psiOp`
        - Source terms and boundary conditions

        The equation is parsed to extract linear, nonlinear, source, and
        pseudo-differential components. Symbolic manipulation is used to derive
        the Fourier representation of linear operators when applicable.

        Parameters
        ----------
        equation : sympy.Eq
            The PDE expressed as a SymPy equation.
        time_scheme : str
            Temporal integration scheme:
                - 'default' for exponential time-stepping
                - 'ETD-RK4' for fourth-order exponential time differencing Runge–Kutta
        dealiasing_ratio : float
            Dealiasing ratio used to truncate high-frequency modes
            (e.g., 2/3 for the standard 2/3-rule truncation).

        Attributes initialized:

        - self.u: the unknown function (e.g., u(t, x))
        - self.dim: spatial dimension (1 or 2)
        - self.spatial_vars: list of spatial variables (e.g., [x] or [x, y])
        - self.is_stationary: boolean indicating if the problem is stationary
        - self.linear_terms: dictionary mapping each linear term (a derivative of u, or u itself) to its coefficient
        - self.nonlinear_terms: list of nonlinear expressions
        - self.source_terms: list of source functions
        - self.pseudo_terms: list of pseudo-differential operator expressions
        - self.has_psi: boolean indicating presence of pseudo-differential operators
        - self.fft / self.ifft: appropriate FFT routines based on spatial dimension
        - self.kx, self.ky: symbolic wavenumber variables for Fourier space

        Raises:
            ValueError: If the equation does not contain exactly one unknown function,
                        if unsupported dimensions are detected, or if the dependencies are invalid.
        """
        self.time_scheme = time_scheme  # 'default' or 'ETD-RK4'
        self.dealiasing_ratio = dealiasing_ratio

        print("\n*********************************")
        print("* Partial differential equation *")
        print("*********************************\n")
        pprint(equation, num_columns=NUM_COLS)

        # Extract symbols and function from the equation
        functions = equation.atoms(Function)

        # Ignore the wrappers psiOp and Op
        excluded_wrappers = {'psiOp', 'Op'}

        # Extract the candidate functions (excluding wrappers)
        candidate_functions = [
            f for f in functions
            if f.func.__name__ not in excluded_wrappers
        ]

        # Keep only user functions (u(x), u(x, t), etc.)
        candidate_functions = [
            f for f in candidate_functions
            if isinstance(f, AppliedUndef)
        ]

        # Stationary detection: no dependence on t
        self.is_stationary = all(
            not any(str(arg) == 't' for arg in f.args)
            for f in candidate_functions
        )

        if len(candidate_functions) != 1:
            print("candidate_functions :", candidate_functions)
            raise ValueError("The equation must contain exactly one unknown function")

        self.u = candidate_functions[0]

        self.u_eq = self.u

        args = self.u.args

        if self.is_stationary:
            if len(args) not in (1, 2):
                raise ValueError("Stationary problems must depend on 1 or 2 spatial variables")
            self.spatial_vars = args
        else:
            if len(args) < 2 or len(args) > 3:
                raise ValueError("The function must depend on t and at least one spatial variable (x [, y])")
            self.t = args[0]
            self.spatial_vars = args[1:]

        self.dim = len(self.spatial_vars)
        if self.dim == 1:
            self.x = self.spatial_vars[0]
            self.y = None
        elif self.dim == 2:
            self.x, self.y = self.spatial_vars
        else:
            raise ValueError("Only 1D and 2D problems are supported.")

        if self.dim == 1:
            self.fft = partial(fft, workers=FFT_WORKERS)
            self.ifft = partial(ifft, workers=FFT_WORKERS)
        else:
            self.fft = partial(fft2, workers=FFT_WORKERS)
            self.ifft = partial(ifft2, workers=FFT_WORKERS)

        # Parse the equation
        self.linear_terms = {}
        self.nonlinear_terms = []
        self.symbol_terms = []
        self.source_terms = []
        self.pseudo_terms = []
        self.temporal_order = 0  # Order of the temporal derivative
        self.linear_terms, self.nonlinear_terms, self.symbol_terms, self.source_terms, self.pseudo_terms = self.parse_equation(equation)
        # Flag: is a pseudo-differential operator present?
        self.has_psi = bool(self.pseudo_terms)
        if self.has_psi:
            print('⚠️  Pseudo‑differential operator detected: all other linear terms have been rejected.')
            # is_spatial: True if some psiOp symbol depends on x (or y)
            self.is_spatial = False
            for coeff, expr in self.pseudo_terms:
                if expr.has(self.x) or (self.dim == 2 and expr.has(self.y)):
                    self.is_spatial = True
                    break

        if self.dim == 1:
            self.kx = symbols('kx')
        elif self.dim == 2:
            self.kx, self.ky = symbols('kx ky')

        # Compute linear operator
        if not self.is_stationary:
            self.compute_linear_operator()
        else:
            self.psi_ops = []
            for coeff, sym_expr in self.pseudo_terms:
                psi = PseudoDifferentialOperator(sym_expr, self.spatial_vars, self.u, mode='symbol')
                self.psi_ops.append((coeff, psi))

Initialize the PDE solver with a given equation.

This method analyzes the input partial differential equation (PDE), identifies the unknown function and its dependencies, determines whether the problem is stationary or time-dependent, and prepares symbolic and numerical structures for solving in spectral space.

Supported features:

  • 1D and 2D problems
  • Time-dependent and stationary equations
  • Linear and nonlinear terms
  • Pseudo-differential operators via psiOp
  • Source terms and boundary conditions

The equation is parsed to extract linear, nonlinear, source, and pseudo-differential components. Symbolic manipulation is used to derive the Fourier representation of linear operators when applicable.

Parameters

  • equation (sympy.Eq): the PDE expressed as a SymPy equation.
  • time_scheme (str): temporal integration scheme, either 'default' for exponential time-stepping or 'ETD-RK4' for fourth-order exponential time differencing Runge–Kutta.
  • dealiasing_ratio (float): dealiasing ratio used to truncate high-frequency modes (e.g., 2/3 for the standard 2/3-rule truncation).
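The role of dealiasing_ratio can be illustrated with the common 2/3 rule: Fourier modes whose wavenumber exceeds ratio × max|k| are zeroed, so that quadratic nonlinearities do not alias into the retained modes. The sketch below builds such a 1D mask with NumPy; the function name dealiasing_mask_1d is hypothetical, and psipy's own mask (including its 2D variant) may be constructed differently.

```python
import numpy as np


def dealiasing_mask_1d(N, L, ratio=2/3):
    """Boolean mask keeping modes with |k| <= ratio * max|k| (hypothetical helper)."""
    k = 2 * np.pi * np.fft.fftfreq(N, d=L / N)  # angular wavenumbers in FFT order
    return np.abs(k) <= ratio * np.abs(k).max()


mask = dealiasing_mask_1d(64, 2 * np.pi)
print(mask.sum())  # 43 modes kept: k = -21, ..., 21

# Applying the mask: zero the truncated modes of a field's spectrum
x = np.linspace(0, 2 * np.pi, 64, endpoint=False)
u = np.sin(3 * x) + 0.1 * np.sin(30 * x)   # k = 30 lies above the cutoff (about 21.3)
u_filtered = np.fft.ifft(np.fft.fft(u) * mask).real
print(np.allclose(u_filtered, np.sin(3 * x)))  # True
```

A 2D mask can be formed as the tensor product of two such 1D masks, one per axis, provided its axis ordering matches that of the spectral arrays.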

Attributes initialized:

  • self.u: the unknown function (e.g., u(t, x))
  • self.dim: spatial dimension (1 or 2)
  • self.spatial_vars: list of spatial variables (e.g., [x] or [x, y])
  • self.is_stationary: boolean indicating if the problem is stationary
  • self.linear_terms: dictionary mapping each linear term (a derivative of u, or u itself) to its coefficient
  • self.nonlinear_terms: list of nonlinear expressions
  • self.source_terms: list of source functions
  • self.pseudo_terms: list of pseudo-differential operator expressions
  • self.has_psi: boolean indicating presence of pseudo-differential operators
  • self.fft / self.ifft: appropriate FFT routines based on spatial dimension
  • self.kx, self.ky: symbolic wavenumber variables for Fourier space

Raises:

ValueError: if the equation does not contain exactly one unknown function, if unsupported dimensions are detected, or if the dependencies are invalid.
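The detection of the unknown function and of stationarity can be sketched with plain SymPy: undefined applied functions (AppliedUndef) are the candidates, and the problem is treated as time-dependent when one argument is named t, in which case t is taken to be the first argument. The helper detect_unknown below is hypothetical and only mirrors the checks described above; psipy additionally filters out its psiOp and Op wrappers by name.

```python
import sympy as sp
from sympy.core.function import AppliedUndef


def detect_unknown(equation):
    """Return (u, is_stationary, spatial_vars) for a SymPy equation (hypothetical helper)."""
    # Applied undefined functions such as u(t, x); sin(x), exp(x), ... are excluded
    candidates = [f for f in equation.atoms(sp.Function) if isinstance(f, AppliedUndef)]
    if len(candidates) != 1:
        raise ValueError("The equation must contain exactly one unknown function")
    u = candidates[0]
    is_stationary = not any(str(arg) == 't' for arg in u.args)
    spatial_vars = u.args if is_stationary else u.args[1:]  # t is assumed to come first
    if len(spatial_vars) not in (1, 2):
        raise ValueError("Only 1D and 2D problems are supported.")
    return u, is_stationary, spatial_vars


t, x, y = sp.symbols('t x y')
u = sp.Function('u')
heat = sp.Eq(sp.diff(u(t, x, y), t), sp.diff(u(t, x, y), x, 2) + sp.diff(u(t, x, y), y, 2))
print(detect_unknown(heat))     # (u(t, x, y), False, (x, y))

poisson = sp.Eq(sp.diff(u(x), x, 2), sp.sin(x))
print(detect_unknown(poisson))  # (u(x), True, (x,))
```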

time_scheme
dealiasing_ratio
is_stationary
u
u_eq
dim
linear_terms
nonlinear_terms
symbol_terms
source_terms
pseudo_terms
temporal_order
has_psi
def parse_equation(self, equation):
    def parse_equation(self, equation):
        """
        Parse the PDE to separate linear and nonlinear terms, symbolic operators (Op),
        source terms, and pseudo-differential operators (psiOp).

        This method rewrites the input equation in standard form (lhs - rhs = 0),
        expands it, and classifies each term into one of the following categories:

        - Linear terms involving derivatives or the unknown function u
        - Nonlinear terms (products with u, powers of u, etc.)
        - Symbolic pseudo-differential operators (Op)
        - Source terms (independent of u)
        - Pseudo-differential operators (psiOp)

        Parameters:
            equation (sympy.Eq): The partial differential equation to be analyzed.
                                 Can be provided as an Eq object or a SymPy expression.

        Returns:
            tuple: A 5-tuple containing:

                - linear_terms (dict): Mapping from derivative/function to coefficient.
                - nonlinear_terms (list): List of terms classified as nonlinear.
                - symbol_terms (list): List of (coefficient, symbolic operator) pairs.
                - source_terms (list): List of terms independent of the unknown function.
                - pseudo_terms (list): List of (coefficient, pseudo-differential symbol) pairs.

        Notes:
            - If `psiOp` is present in the equation, expansion is skipped for safety.
            - When `psiOp` is used, only nonlinear terms, source terms, a plain u term,
              and possibly a time derivative are allowed; other linear terms and symbolic
              operators (Op) are forbidden.
            - Classification logic includes:
                - Detection of nonlinear structures like products or powers of u
                - Mixed terms involving both u and its derivatives
                - External symbolic operators (Op) and pseudo-differential operators (psiOp)
        """
        def is_nonlinear_term(term, u_func):
            # Functions (Abs, sin, exp, ...) applied to u make the term nonlinear
            if term.has(u_func):
                for sub in preorder_traversal(term):
                    if isinstance(sub, Function) and sub.has(u_func) and sub.func != u_func.func:
                        return True
            # Powers of u or of its derivatives (u**2, u_x**2, 1/u, ...)
            if term.has(Pow):
                for pow_term in term.atoms(Pow):
                    if pow_term.base.has(u_func) and pow_term.exp != 1:
                        return True
            # Products with more than one factor depending on u (u*u_x, u_x*u_xx, ...)
            if term.func == Mul:
                u_factors = [f for f in term.args if f.has(u_func)]
                if len(u_factors) > 1:
                    return True
            return False

        print("\n********************")
        print("* Equation parsing *")
        print("********************\n")

        if isinstance(equation, Eq):
            lhs = equation.lhs - equation.rhs
        else:
            lhs = equation

        print(f"\nEquation rewritten in standard form: {lhs}")
        if lhs.has(psiOp):
            print("⚠️ psiOp detected: skipping expansion for safety")
            lhs_expanded = lhs
        else:
            lhs_expanded = expand(lhs)

        print(f"\nExpanded equation: {lhs_expanded}")

        linear_terms = {}
        nonlinear_terms = []
        symbol_terms = []
        source_terms = []
        pseudo_terms = []

        for term in lhs_expanded.as_ordered_terms():
            print(f"Analyzing term: {term}")

            if isinstance(term, psiOp):
                expr = term.args[0]
                pseudo_terms.append((1, expr))
                print("  --> Classified as pseudo linear term (psiOp)")
                continue

            # Otherwise, look for psiOp inside (general case)
            if term.has(psiOp):
                psiops = term.atoms(psiOp)
                for psi in psiops:
                    try:
                        coeff = simplify(term / psi)
                        expr = psi.args[0]
                        pseudo_terms.append((coeff, expr))
                        print("  --> Classified as pseudo linear term (psiOp)")
                    except Exception as e:
                        print(f"  ⚠️ Failed to extract psiOp coefficient in term: {term}")
                        print(f"     Reason: {e}")
                        nonlinear_terms.append(term)
                        print("  --> Fallback: classified as nonlinear")
                continue

            if term.has(Op):
                ops = term.atoms(Op)
                for op in ops:
                    coeff = term / op
                    expr = op.args[0]
                    symbol_terms.append((coeff, expr))
                    print("  --> Classified as symbolic linear term (Op)")
                continue

            if is_nonlinear_term(term, self.u):
                nonlinear_terms.append(term)
                print("  --> Classified as nonlinear")
                continue

            derivs = term.atoms(Derivative)
            if derivs:
                deriv = derivs.pop()
                coeff = term / deriv
                linear_terms[deriv] = linear_terms.get(deriv, 0) + coeff
                print(f"  Derivative found: {deriv}")
                print("  --> Classified as linear")
            elif self.u in term.atoms(Function):
                # Divide by u so that spatially varying coefficients (e.g. x*u) are kept
                coeff = term / self.u
                linear_terms[self.u] = linear_terms.get(self.u, 0) + coeff
                print("  --> Classified as linear")
            else:
                source_terms.append(term)
                print("  --> Classified as source term")

        print(f"Final linear terms: {linear_terms}")
        print(f"Final nonlinear terms: {nonlinear_terms}")
        print(f"Symbol terms: {symbol_terms}")
        print(f"Pseudo terms: {pseudo_terms}")
        print(f"Source terms: {source_terms}")

        if pseudo_terms:
            # Only a time derivative and the plain u term may accompany psiOp among linear terms
            t_var = getattr(self, 't', None)  # self.t is not defined for stationary problems
            invalid_linear_terms = {
                term: coeff for term, coeff in linear_terms.items()
                if not (
                    isinstance(term, Derivative)
                    and t_var in [v for v, _ in term.variable_count]
                )
                and term != self.u  # the plain u term (without derivative) is allowed
            }

            if invalid_linear_terms or symbol_terms:
                raise ValueError(
                    "When psiOp is used, only nonlinear terms, source terms, a plain u term, "
                    "and possibly a time derivative are allowed. "
                    "Other linear terms and Ops are forbidden."
                )

        return linear_terms, nonlinear_terms, symbol_terms, source_terms, pseudo_terms

Parse the PDE to separate linear and nonlinear terms, symbolic operators (Op), source terms, and pseudo-differential operators (psiOp).

This method rewrites the input equation in standard form (lhs - rhs = 0), expands it, and classifies each term into one of the following categories:

  • Linear terms involving derivatives or the unknown function u
  • Nonlinear terms (products with u, powers of u, etc.)
  • Symbolic pseudo-differential operators (Op)
  • Source terms (independent of u)
  • Pseudo-differential operators (psiOp)

Parameters:

  • equation (sympy.Eq): the partial differential equation to be analyzed. Can be provided as an Eq object or a SymPy expression.

Returns:

tuple: A 5-tuple containing:

    - linear_terms (dict): Mapping from derivative/function to coefficient.
    - nonlinear_terms (list): List of terms classified as nonlinear.
    - symbol_terms (list): List of (coefficient, symbolic operator) pairs.
    - source_terms (list): List of terms independent of the unknown function.
    - pseudo_terms (list): List of (coefficient, pseudo-differential symbol) pairs.

Notes:

  • If psiOp is present in the equation, expansion is skipped for safety.
  • When psiOp is used, only nonlinear terms, source terms, a plain u term, and possibly a time derivative are allowed; other linear terms and symbolic operators (Op) are forbidden.
  • Classification logic includes:
    • Detection of nonlinear structures like products or powers of u
    • Mixed terms involving both u and its derivatives
    • External symbolic operators (Op) and pseudo-differential operators (psiOp)
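The classification rules can be seen in a stripped-down classifier for polynomial-type PDEs: a term whose only u-dependent factor is u itself or a single derivative of u is linear, a term with no u-dependence is a source, and everything else (powers, products of u-dependent factors, functions of u) is nonlinear. The classify_terms helper below is hypothetical; it ignores psiOp and Op, which psipy handles separately, and is a sketch rather than the library's code.

```python
import sympy as sp


def classify_terms(expr, u):
    """Split an expanded expression into linear, nonlinear and source parts (sketch)."""
    linear, nonlinear, source = {}, [], []
    for term in sp.Add.make_args(sp.expand(expr)):
        # Factors of the term that depend on the unknown function
        u_factors = [f for f in sp.Mul.make_args(term) if f.has(u)]
        if not u_factors:
            source.append(term)
        elif len(u_factors) == 1 and (u_factors[0] == u or isinstance(u_factors[0], sp.Derivative)):
            op = u_factors[0]
            linear[op] = linear.get(op, 0) + term / op  # coefficient of u or of the derivative
        else:
            nonlinear.append(term)
    return linear, nonlinear, source


t, x = sp.symbols('t x')
u = sp.Function('u')(t, x)
# u_t = u_xx + u**2 + sin(x), rewritten as lhs - rhs = 0
eq = sp.Eq(sp.diff(u, t), sp.diff(u, x, 2) + u**2 + sp.sin(x))
linear, nonlinear, source = classify_terms(eq.lhs - eq.rhs, u)
print(linear)
print(nonlinear)  # [-u(t, x)**2]
print(source)     # [-sin(x)]
```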

def compute_linear_operator(self):
    def compute_linear_operator(self):
        """
        Compute the symbolic Fourier representation L(k) of the linear operator
        derived from the linear part of the PDE.

        This method constructs a dispersion relation by applying each symbolic derivative
        to a plane wave exp(i(k·x - ωt)) and extracting the resulting expression.
        It handles arbitrary derivative combinations and includes symbolic and
        pseudo-differential terms.

        Steps:
        -------
        1. Construct a plane wave φ(x, t) = exp(i(k·x - ωt)).
        2. Apply each term from self.linear_terms to φ (∂t → -iω, ∂x → i kx, ∂y → i ky).
        3. Normalize by φ and simplify to obtain the characteristic equation in ω.
        4. Add symbolic operator terms (Op), if present.
        5. Detect the temporal order from the degree of the characteristic equation in ω.
        6. If psiOp terms are present, wrap them as PseudoDifferentialOperator objects;
           otherwise, solve for ω, set L(k) = -iω (first order in time) or L(k) = -ω²
           (second order), and build the numerical function L(k) via lambdify.

        Sets:
        -----
        - self.L_symbolic : sympy.Expr
            Symbolic form of L(k).
        - self.L : callable
            Numerical function of L(kx[, ky]).
        - self.omega : callable or None
            Frequency root ω(k), if available.
        - self.temporal_order : int
            Order of time derivatives detected.
        - self.psi_ops : list of (coeff, PseudoDifferentialOperator)
            Pseudo-differential terms present in the equation.

        Raises:
        -------
        ValueError if the dimension is unsupported or the dispersion relation fails.
        """
        print("\n*******************************")
        print("* Linear operator computation *")
        print("*******************************\n")

        # --- Step 1: symbolic variables ---
        omega = symbols("omega")
        if self.dim == 1:
            kvars = [symbols("kx")]
            space_vars = [self.x]
        elif self.dim == 2:
            kvars = symbols("kx ky")
            space_vars = [self.x, self.y]
        else:
            raise ValueError("Only 1D and 2D are supported.")

        kdict = dict(zip(space_vars, kvars))
        self.k_symbols = kvars

        # Plane wave expression
        phase = sum(k * x for k, x in zip(kvars, space_vars)) - omega * self.t
        plane_wave = exp(I * phase)

        # --- Step 2: build lhs expression from linear terms ---
        lhs = 0
        for deriv, coeff in self.linear_terms.items():
            if isinstance(deriv, Derivative):
                total_factor = 1
                for var, n in deriv.variable_count:
                    if var == self.t:
                        total_factor *= (-I * omega)**n
                    elif var in kdict:
                        total_factor *= (I * kdict[var])**n
                    else:
                        raise ValueError(f"Unknown variable {var} in derivative")
                lhs += coeff * total_factor * plane_wave
            elif deriv == self.u:
                lhs += coeff * plane_wave
            else:
                raise ValueError(f"Unsupported linear term: {deriv}")

        # --- Step 3: dispersion relation ---
        equation = simplify(lhs / plane_wave)
        print("\nCharacteristic equation before symbol treatment:")
        pprint(equation, num_columns=NUM_COLS)

        print("\n--- Symbolic symbol analysis ---")
        symb_omega = 0
        symb_k = 0

        for coeff, symbol in self.symbol_terms:
            if symbol.has(omega):
                # Add omega-dependent terms directly
                symb_omega += coeff * symbol
            elif any(symbol.has(k) for k in self.k_symbols):
                symb_k += coeff * symbol.subs(dict(zip(symbol.free_symbols, self.k_symbols)))

        print(f"symb_omega: {symb_omega}")
        print(f"symb_k: {symb_k}")

        equation = equation + symb_omega + symb_k

        print("\nRaw characteristic equation:")
        pprint(equation, num_columns=NUM_COLS)

        # Temporal derivative order detection
        try:
            poly_eq = Eq(equation, 0)
            poly = poly_eq.lhs.as_poly(omega)
            self.temporal_order = poly.degree() if poly else 0
        except Exception as e:
            warnings.warn(f"Could not determine temporal order: {e}", RuntimeWarning)
            self.temporal_order = 0
        print(f"Temporal order from dispersion relation: {self.temporal_order}")
        print('self.pseudo_terms = ', self.pseudo_terms)
        if self.pseudo_terms:
            # psiOp symbols are normalized by the coefficient of the time derivative
            coeff_time = 1
            for term, coeff in self.linear_terms.items():
                if isinstance(term, Derivative) and any(var == self.t for var, _ in term.variable_count):
                    coeff_time = coeff
                    print(f"✅ Time derivative coefficient detected: {coeff_time}")
            self.psi_ops = []
            for coeff, sym_expr in self.pseudo_terms:
                # sym_expr is the SymPy symbol expression; self.spatial_vars is [x] or [x, y]
                psi = PseudoDifferentialOperator(sym_expr / coeff_time, self.spatial_vars, self.u, mode='symbol')

                self.psi_ops.append((coeff, psi))
        else:
            dispersion = solve(Eq(equation, 0), omega)
            if not dispersion:
                raise ValueError("No solution found for omega")
            print("\n--- Solutions found ---")
            pprint(dispersion, num_columns=NUM_COLS)

            if self.temporal_order == 2:
                # Second order in time: u_tt = L u with L(k) = -omega(k)**2
                omega_expr = simplify(sqrt(dispersion[0]**2))
                self.omega_symbolic = omega_expr
                self.omega = lambdify(self.k_symbols, omega_expr, "numpy")
                self.L_symbolic = -omega_expr**2
            else:
                # First order in time: u_t = L u with L(k) = -i*omega(k)
                self.L_symbolic = -I * dispersion[0]

            self.L = lambdify(self.k_symbols, self.L_symbolic, "numpy")

            print("\n--- Final linear operator ---")
            pprint(self.L_symbolic, num_columns=NUM_COLS)

Compute the symbolic Fourier representation L(k) of the linear operator derived from the linear part of the PDE.

This method constructs a dispersion relation by applying each symbolic derivative to a plane wave exp(i(k·x - ωt)) and extracting the resulting expression. It handles arbitrary derivative combinations and includes symbolic and pseudo-differential terms.

Steps:

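  1. Construct a plane wave φ(x, t) = exp(i(k·x - ωt)).
  2. Apply each term from self.linear_terms to φ (∂t → -iω, ∂x → i kx, ∂y → i ky).
  3. Normalize by φ and simplify to obtain the characteristic equation in ω.
  4. Add symbolic operator terms (Op), if present.
  5. Detect the temporal order from the degree of the characteristic equation in ω.
  6. If psiOp terms are present, wrap them as PseudoDifferentialOperator objects; otherwise, solve for ω, set L(k) = -iω (first order in time) or L(k) = -ω² (second order), and build the numerical function L(k) via lambdify.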

Sets:

  • self.L_symbolic : sympy.Expr Symbolic form of L(k).
  • self.L : callable Numerical function of L(kx[, ky]).
  • self.omega : callable or None Frequency root ω(k), if available.
  • self.temporal_order : int Order of time derivatives detected.
  • self.psi_ops : list of (coeff, PseudoDifferentialOperator) Pseudo-differential terms present in the equation.

Raises:

ValueError if the dimension is unsupported or the dispersion relation fails.

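Steps 1 to 3 can be illustrated with a short SymPy sketch for a single 1D equation. This is not psipy's implementation. `dispersion_symbol`, `heat` and `schrodinger` are hypothetical names, and the PDE is passed as a Python function acting on an expression. Only first-order-in-time equations are handled, and the first root ω(k) is kept, mirroring the L(k) = -i·ω(k) convention of the code above.

```python
import sympy as sp

x, t = sp.symbols("x t", real=True)
k, omega = sp.symbols("k omega")

def dispersion_symbol(pde_operator):
    """Hypothetical helper: L(k) for a linear PDE P[u] = 0 that is first order in time."""
    # Plane wave phi(x, t) = exp(i(k x - omega t))
    phi = sp.exp(sp.I * (k * x - omega * t))
    # Apply the operator and divide by phi: D(k, omega) = P[phi] / phi
    relation = sp.simplify(pde_operator(phi) / phi)
    roots = sp.solve(sp.Eq(relation, 0), omega)
    if not roots:
        raise ValueError("No solution found for omega")
    # A mode exp(-i omega t) solves u_t = L u with L(k) = -i omega(k)
    return sp.simplify(-sp.I * roots[0])

def heat(w):
    return sp.diff(w, t) - sp.diff(w, x, 2)           # u_t - u_xx

def schrodinger(w):
    return sp.I * sp.diff(w, t) + sp.diff(w, x, 2)    # i u_t + u_xx

print(dispersion_symbol(heat))          # -k**2
print(dispersion_symbol(schrodinger))   # -I*k**2
```

Passing the operator as a function avoids substituting a plane wave into unevaluated `Derivative` objects. The solver, by contrast, applies each entry of self.linear_terms to φ.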
def linear_rhs(self, u, is_v=False):
    def linear_rhs(self, u, is_v=False):
        """
        Apply the linear operator L(k) to the field u or v in Fourier space,
        apply the dealiasing mask, and transform back to physical space.

        Parameters
        ----------
        u : np.ndarray
            Input field in physical space.
        is_v : bool
            Whether the input is the velocity field v rather than u
            (both currently use the same symbol).

        Returns
        -------
        np.ndarray
            Result of applying the linear operator, in physical space.
        """
        if self.dim == 1:
            self.symbol_u = np.array(self.L(self.KX), dtype=np.complex128)
            self.symbol_v = self.symbol_u  # same operator for u and v
        elif self.dim == 2:
            self.symbol_u = np.array(self.L(self.KX, self.KY), dtype=np.complex128)
            self.symbol_v = self.symbol_u
        u_hat = self.fft(u)
        u_hat *= self.symbol_v if is_v else self.symbol_u
        u_hat *= self.dealiasing_mask
        return self.ifft(u_hat)

Apply the linear operator L(k) to the field u or v in Fourier space, apply the dealiasing mask, and transform back to physical space.

Parameters

  • u : np.ndarray. Input field in physical space.
  • is_v : bool. Whether the input is the velocity field v rather than u (both currently use the same symbol).

Returns

np.ndarray. Result of applying the linear operator, in physical space.

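The multiply-in-Fourier-space pattern of linear_rhs can be reproduced with plain NumPy. In this sketch `np.fft` stands in for the solver's `self.fft`/`self.ifft`, whose backend is not shown here. `apply_fourier_multiplier` is a hypothetical name, and the 2/3 cutoff is only an example value for `dealiasing_ratio`.

```python
import numpy as np

def apply_fourier_multiplier(u, symbol, mask=None):
    """Multiply the spectrum of u by symbol(k), optionally dealias, and transform back."""
    u_hat = np.fft.fft(u)
    u_hat = u_hat * symbol
    if mask is not None:
        u_hat = u_hat * mask          # zero out modes above the cutoff
    return np.fft.ifft(u_hat)

# Periodic grid on [-Lx/2, Lx/2), same conventions as setup_1D
Lx, Nx = 2 * np.pi, 64
x = np.linspace(-Lx / 2, Lx / 2, Nx, endpoint=False)
kx = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)

# Symbol of d^2/dx^2 is (i k)^2 = -k^2; sharp cutoff at 2/3 of max |k|
symbol = -kx**2
mask = np.abs(kx) <= (2.0 / 3.0) * np.max(np.abs(kx))

result = apply_fourier_multiplier(np.sin(x), symbol, mask)
print(np.allclose(result.real, -np.sin(x)))   # True
```

The symbol depends only on the grid. Computing it once and reusing it is therefore a cheap optimization compared with re-evaluating `self.L` on every call, as linear_rhs currently does.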
def setup(self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic', initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
    def setup(self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic',
              initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
        """
        Configure the spatial/temporal grid and initialize the solution field.

        This method sets up the computational domain, initializes spatial and temporal grids,
        applies boundary conditions, and prepares symbolic and numerical operators.
        It also performs essential analyses such as:

            - CFL condition verification (for stability)
            - Symbol analysis (e.g., dispersion relation, regularity)
            - Wave propagation analysis for second-order equations

        If pseudo-differential operators (ψOp) are present, symbolic analysis (including the
        CFL check) is skipped in favor of interactive exploration via `interactive_symbol_analysis`.

        Parameters
        ----------
        Lx : float
            Size of the spatial domain along the x-axis.
        Ly : float, optional
            Size of the spatial domain along the y-axis (required in 2D).
        Nx : int
            Number of spatial points along the x-axis (required).
        Ny : int, optional
            Number of spatial points along the y-axis (required in 2D).
        Lt : float, default=1.0
            Total simulation time.
        Nt : int, default=100
            Number of time steps; the time step is dt = Lt / Nt.
        boundary_condition : str, default='periodic'
            Boundary treatment. 'dirichlet' is only accepted when the equation contains a psiOp.
        initial_condition : callable
            Function returning the initial state u(x, 0) or u(x, y, 0);
            required unless the problem is stationary.
        initial_velocity : callable, optional
            Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0),
            required for second-order equations.
        n_frames : int, default=100
            Number of time frames to store during simulation for visualization or output.
        plot : bool, default=True
            Whether to plot the symbol and, for second-order equations, run the
            wave propagation analysis.

        Raises
        ------
        ValueError
            If mandatory parameters are missing (Nx in 1D; Nx, Ly or Ny in 2D), or if
            Dirichlet boundary conditions are requested without a psiOp.

        Notes
        -----
        - The spatial discretization assumes periodic boundary conditions by default.
        - Fourier transforms use full complex FFTs (e.g., `scipy.fft.fft`, `fft2`).
        - Wavenumber arrays (`KX`, `KY`) follow the standard `fftfreq` ordering (zero frequency first).
        - Dealiasing uses a sharp cutoff at the fraction `dealiasing_ratio` of the maximum wavenumber.
        - For second-order equations, the initial acceleration is derived from the governing operator.
        - Symbolic analysis includes plotting of the symbol's real, imaginary and absolute values
          and of the dispersion relation (only when `plot=True`).

        See Also
        --------
        setup_1D : Sets up internal variables for one-dimensional problems.
        setup_2D : Sets up internal variables for two-dimensional problems.
        initialize_conditions : Applies initial data and enforces compatibility.
        check_cfl_condition : Verifies time step against stability constraints.
        plot_symbol : Visualizes the linear operator’s symbol in frequency space.
        analyze_wave_propagation : Analyzes group velocity.
        interactive_symbol_analysis : Interactive tools for ψOp-based equations.
        """

        # Temporal parameters
        self.Lt, self.Nt = Lt, Nt
        self.dt = Lt / Nt
        self.n_frames = n_frames
        self.frames = []
        self.initial_condition = initial_condition
        self.boundary_condition = boundary_condition
        self.plot = plot

        if self.boundary_condition == 'dirichlet' and not self.has_psi:
            raise ValueError(
                "Dirichlet boundary conditions require the equation to be defined via a pseudo-differential operator (psiOp). "
                "Please provide an equation involving psiOp for non-periodic boundary treatment."
            )

        # Dimension checks
        if self.dim == 1:
            if Nx is None:
                raise ValueError("Nx must be specified in 1D.")
            self.setup_1D(Lx, Nx)
        else:
            if None in (Nx, Ly, Ny):
                raise ValueError("In 2D, Nx, Ly and Ny must be provided.")
            self.setup_2D(Lx, Ly, Nx, Ny)

        # Initialization of solution and velocities
        if not self.is_stationary:
            self.initialize_conditions(initial_condition, initial_velocity)

        # Symbol analysis (skipped when psiOp terms are present)
        if self.has_psi:
            print("⚠️ For psiOp, use interactive_symbol_analysis.")
        else:
            if self.L_symbolic == 0:
                print("⚠️ Linear operator is null.")
            else:
                self.check_cfl_condition()
                self.check_symbol_conditions()
                if plot:
                    self.plot_symbol()
                    if self.temporal_order == 2:
                        self.analyze_wave_propagation()

Configure the spatial/temporal grid and initialize the solution field.

This method sets up the computational domain, initializes spatial and temporal grids, applies boundary conditions, and prepares symbolic and numerical operators. It also performs essential analyses such as:

- CFL condition verification (for stability)
- Symbol analysis (e.g., dispersion relation, regularity)
- Wave propagation analysis for second-order equations

If pseudo-differential operators (ψOp) are present, symbolic analysis (including the CFL check) is skipped in favor of interactive exploration via interactive_symbol_analysis.

Parameters

  • Lx : float. Size of the spatial domain along the x-axis.
  • Ly : float, optional. Size of the spatial domain along the y-axis (required in 2D).
  • Nx : int. Number of spatial points along the x-axis (required).
  • Ny : int, optional. Number of spatial points along the y-axis (required in 2D).
  • Lt : float, default=1.0. Total simulation time.
  • Nt : int, default=100. Number of time steps; the time step is dt = Lt / Nt.
  • boundary_condition : str, default='periodic'. Boundary treatment; 'dirichlet' is only accepted when the equation contains a psiOp.
  • initial_condition : callable. Function returning the initial state u(x, 0) or u(x, y, 0); required unless the problem is stationary.
  • initial_velocity : callable, optional. Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0); required for second-order equations.
  • n_frames : int, default=100. Number of time frames to store during simulation for visualization or output.
  • plot : bool, default=True. Whether to plot the symbol and, for second-order equations, run the wave propagation analysis.

Raises

ValueError if mandatory parameters are missing (Nx in 1D; Nx, Ly or Ny in 2D), or if Dirichlet boundary conditions are requested without a psiOp.

Notes

  • The spatial discretization assumes periodic boundary conditions by default.
  • Fourier transforms use full complex FFTs (e.g., scipy.fft.fft, fft2).
  • Wavenumber arrays (KX, KY) follow the standard fftfreq ordering (zero frequency first).
  • Dealiasing uses a sharp cutoff at the fraction dealiasing_ratio of the maximum wavenumber.
  • For second-order equations, the initial acceleration is derived from the governing operator.
  • Symbolic analysis includes plotting of the symbol's real, imaginary and absolute values and of the dispersion relation (only when plot=True).

See Also

  • setup_1D : Sets up internal variables for one-dimensional problems.
  • setup_2D : Sets up internal variables for two-dimensional problems.
  • initialize_conditions : Applies initial data and enforces compatibility.
  • check_cfl_condition : Verifies time step against stability constraints.
  • plot_symbol : Visualizes the linear operator’s symbol in frequency space.
  • analyze_wave_propagation : Analyzes group velocity.
  • interactive_symbol_analysis : Interactive tools for ψOp-based equations.

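setup fixes dt = Lt / Nt before calling check_cfl_condition, whose criterion is not shown in this section. The sketch below is an assumption-labeled illustration only. It applies the textbook forward-Euler bound |1 + L(k)·dt| ≤ 1, which for a real, non-positive symbol means dt ≤ 2 / max|L(k)|. `max_stable_dt_forward_euler` is a hypothetical helper, not the solver's check.

```python
import numpy as np

def max_stable_dt_forward_euler(L_vals):
    """Largest dt with |1 + L(k)*dt| <= 1 for all modes, assuming real L(k) <= 0 (hypothetical)."""
    L_vals = np.asarray(L_vals, dtype=float)
    if np.any(L_vals > 0):
        raise ValueError("Symbol has growing modes; this criterion does not apply.")
    L_max = np.max(np.abs(L_vals))
    return np.inf if L_max == 0 else 2.0 / L_max

# Grid and time step built with the same conventions as setup/setup_1D
Lx, Nx, Lt, Nt = 2 * np.pi, 64, 1.0, 100
kx = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)
dt = Lt / Nt

dt_max = max_stable_dt_forward_euler(-kx**2)   # heat equation: L(k) = -k^2
print(f"dt = {dt}, forward-Euler limit = {dt_max:.3e}, stable: {dt <= dt_max}")
```

For the heat equation this bound shrinks like 1/k_max², so refining the grid quickly forces tiny explicit steps. Exponential integrators such as ETD-RK4 treat the linear part exactly, which removes that particular constraint.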
def setup_1D(self, Lx, Nx):
    def setup_1D(self, Lx, Nx):
        """
        Configure internal variables for one-dimensional (1D) problems.

        This private method initializes spatial and frequency grids, builds the dealiasing mask,
        and prepares either pseudo-differential symbols or linear operators for use in time evolution.

        It assumes periodic boundary conditions and uses full complex FFT conventions (`fftfreq`).
        The spatial domain is centered at zero: [-Lx/2, Lx/2).

        Parameters
        ----------
        Lx : float
            Physical size of the spatial domain along the x-axis.
        Nx : int
            Number of grid points in the x-direction.

        Attributes Set
        --------------
        - self.Lx : float
            Size of the spatial domain.
        - self.Nx : int
            Number of spatial points.
        - self.x_grid : np.ndarray
            1D array of spatial coordinates.
        - self.X : np.ndarray
            Alias to `self.x_grid`, used in physical space computations.
        - self.kx : np.ndarray
            Array of wavenumbers corresponding to the Fourier transform.
        - self.KX : np.ndarray
            Alias to `self.kx`, used in frequency space computations.
        - self.dealiasing_mask : np.ndarray
            Boolean mask keeping |k| <= dealiasing_ratio * max|k|, used to suppress aliased frequencies.
        - self.exp_L : np.ndarray
            Exponential of the linear operator scaled by the time step, exp(L(k) · dt) (only without ψOp).
        - self.omega_val : np.ndarray
            Frequency values ω(k), with L(k) = -ω(k)², used in second-order time stepping.
        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
            Cosine and sine of ω(k)·dt for dispersive propagation.
        - self.inv_omega : np.ndarray
            Inverse of ω(k), set to zero where ω(k) == 0 to avoid division by zero.

        Notes
        -----
        - Wavenumbers are computed with `scipy.fft.fftfreq` and kept in standard FFT order
          (zero frequency first); no `fftshift` is applied.
        - Dealiasing is applied using a sharp cutoff filter based on `self.dealiasing_ratio`.
        - If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via `prepare_symbol_tables`.
        - For second-order equations, ω(k) is evaluated from `self.omega`, the dispersion relation
          consistent with L(k) = -ω(k)²; the ω-dependent terms are set by `setup_omega_terms`.

        See Also
        --------
        setup_2D : Equivalent setup for two-dimensional problems.
        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
        setup_omega_terms : Sets up terms involving ω(k) for second-order evolution.
        """
        self.Lx, self.Nx = Lx, Nx
        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
        self.X = self.x_grid
        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
        self.KX = self.kx

        # Dealiasing mask
        k_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
        self.dealiasing_mask = (np.abs(self.KX) <= k_max)

        # Preparation of symbol or linear operator
        if self.has_psi:
            self.prepare_symbol_tables()
        else:
            L_vals = np.array(self.L(self.KX), dtype=np.complex128)
            self.exp_L = np.exp(L_vals * self.dt)
            if self.temporal_order == 2:
                omega_val = self.omega(self.KX)
                self.setup_omega_terms(omega_val)

Configure internal variables for one-dimensional (1D) problems.

This private method initializes spatial and frequency grids, builds the dealiasing mask, and prepares either pseudo-differential symbols or linear operators for use in time evolution.

It assumes periodic boundary conditions and uses full complex FFT conventions (fftfreq). The spatial domain is centered at zero: [-Lx/2, Lx/2).

Parameters

  • Lx : float. Physical size of the spatial domain along the x-axis.
  • Nx : int. Number of grid points in the x-direction.

Attributes Set

  • self.Lx : float. Size of the spatial domain.
  • self.Nx : int. Number of spatial points.
  • self.x_grid : np.ndarray. 1D array of spatial coordinates.
  • self.X : np.ndarray. Alias to self.x_grid, used in physical space computations.
  • self.kx : np.ndarray. Array of wavenumbers corresponding to the Fourier transform.
  • self.KX : np.ndarray. Alias to self.kx, used in frequency space computations.
  • self.dealiasing_mask : np.ndarray. Boolean mask keeping |k| <= dealiasing_ratio * max|k|, used to suppress aliased frequencies.
  • self.exp_L : np.ndarray. Exponential of the linear operator scaled by the time step, exp(L(k) · dt) (only without ψOp).
  • self.omega_val : np.ndarray. Frequency values ω(k), with L(k) = -ω(k)², used in second-order time stepping.
  • self.cos_omega_dt, self.sin_omega_dt : np.ndarray. Cosine and sine of ω(k)·dt for dispersive propagation.
  • self.inv_omega : np.ndarray. Inverse of ω(k), set to zero where ω(k) == 0 to avoid division by zero.

Notes

  • Wavenumbers are computed with scipy.fft.fftfreq and kept in standard FFT order (zero frequency first); no fftshift is applied.
  • Dealiasing is applied using a sharp cutoff filter based on self.dealiasing_ratio.
  • If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via prepare_symbol_tables.
  • For second-order equations, ω(k) is evaluated from self.omega, the dispersion relation consistent with L(k) = -ω(k)²; the ω-dependent terms are set by setup_omega_terms.

See Also

  • setup_2D : Equivalent setup for two-dimensional problems.
  • prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
  • setup_omega_terms : Sets up terms involving ω(k) for second-order evolution.

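The following standalone NumPy sketch reproduces the same 1D conventions: a half-open grid centered at zero, `fftfreq` wavenumbers left in FFT order, and a sharp cutoff mask. It also shows the exact linear propagator exp(L(k)·dt), which the solver stores as exp_L. `periodic_grid_1d`, `exact_linear_step` and the default ratio 2/3 are assumptions for illustration; the solver reads its ratio from `self.dealiasing_ratio`.

```python
import numpy as np

def periodic_grid_1d(Lx, Nx, dealiasing_ratio=2.0 / 3.0):
    """Grid, wavenumbers and dealiasing mask (hypothetical helper)."""
    x = np.linspace(-Lx / 2, Lx / 2, Nx, endpoint=False)   # [-Lx/2, Lx/2)
    kx = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)         # FFT order, zero first
    mask = np.abs(kx) <= dealiasing_ratio * np.max(np.abs(kx))
    return x, kx, mask

def exact_linear_step(u, kx, L_func, dt):
    """Advance u_t = L u by dt exactly in Fourier space: u_hat <- exp(L(k) dt) u_hat."""
    exp_L = np.exp(np.asarray(L_func(kx), dtype=np.complex128) * dt)
    return np.fft.ifft(exp_L * np.fft.fft(u))

x, kx, mask = periodic_grid_1d(Lx=2 * np.pi, Nx=8)
u1 = exact_linear_step(np.sin(x), kx, lambda k: -k**2, dt=0.1)   # heat equation
```

With Nx = 8 the wavenumbers come out in FFT order, 0, 1, 2, 3, -4, -3, -2, -1, with zero frequency first and no shift applied. The mask then keeps only |k| ≤ 2, and one step of the heat equation multiplies sin(x) by exp(-0.1).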
def setup_2D(self, Lx, Ly, Nx, Ny):
    def setup_2D(self, Lx, Ly, Nx, Ny):
        """
        Configure internal variables for two-dimensional (2D) problems.

        This private method initializes spatial and frequency grids, builds the dealiasing mask,
        and prepares either pseudo-differential symbols or linear operators for use in time evolution.

        It assumes periodic boundary conditions and uses full complex FFT conventions (`fftfreq`).
        The spatial domain is centered at zero: [-Lx/2, Lx/2) × [-Ly/2, Ly/2).

        Parameters
        ----------
        Lx : float
            Physical size of the spatial domain along the x-axis.
        Ly : float
            Physical size of the spatial domain along the y-axis.
        Nx : int
            Number of grid points along the x-direction.
        Ny : int
            Number of grid points along the y-direction.

        Attributes Set
        --------------
        - self.Lx, self.Ly : float
            Size of the spatial domain in each direction.
        - self.Nx, self.Ny : int
            Number of spatial points in each direction.
        - self.x_grid, self.y_grid : np.ndarray
            1D arrays of spatial coordinates in x and y directions.
        - self.X, self.Y : np.ndarray
            2D meshgrids of spatial coordinates (indexing='ij', shape (Nx, Ny)).
        - self.kx, self.ky : np.ndarray
            Arrays of wavenumbers corresponding to Fourier transforms in x and y directions.
        - self.KX, self.KY : np.ndarray
            Meshgrids of wavenumbers (indexing='ij'), used in frequency space computations.
        - self.dealiasing_mask : np.ndarray
            Boolean mask keeping |kx| <= dealiasing_ratio * max|kx| and |ky| <= dealiasing_ratio * max|ky|.
        - self.exp_L : np.ndarray
            Exponential of the linear operator scaled by the time step, exp(L(kx, ky) · dt) (only without ψOp).
        - self.omega_val : np.ndarray
            Frequency values ω(kx, ky), with L(kx, ky) = -ω(kx, ky)², used in second-order time stepping.
        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
            Cosine and sine of ω(kx, ky)·dt for dispersive propagation.
        - self.inv_omega : np.ndarray
            Inverse of ω(kx, ky), set to zero where ω == 0 to avoid division by zero.

        Notes
        -----
        - Wavenumbers are computed with `scipy.fft.fftfreq` and kept in standard FFT order
          (zero frequency first); no `fftshift` is applied.
        - Dealiasing is applied using a sharp rectangular cutoff based on `self.dealiasing_ratio`.
        - If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via `prepare_symbol_tables`.
        - For second-order equations, ω(kx, ky) is evaluated from `self.omega`, the dispersion relation
          consistent with L(kx, ky) = -ω(kx, ky)²; the ω-dependent terms are set by `setup_omega_terms`.

        See Also
        --------
        setup_1D : Equivalent setup for one-dimensional problems.
        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
        setup_omega_terms : Sets up terms involving ω(kx, ky) for second-order evolution.
        """
        self.Lx, self.Ly = Lx, Ly
        self.Nx, self.Ny = Nx, Ny
        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
        self.y_grid = np.linspace(-Ly/2, Ly/2, Ny, endpoint=False)
        self.X, self.Y = np.meshgrid(self.x_grid, self.y_grid, indexing='ij')
        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
        self.ky = 2 * np.pi * fftfreq(Ny, d=Ly / Ny)
        self.KX, self.KY = np.meshgrid(self.kx, self.ky, indexing='ij')

        # Dealiasing mask
        kx_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
        ky_max = self.dealiasing_ratio * np.max(np.abs(self.ky))
        self.dealiasing_mask = (np.abs(self.KX) <= kx_max) & (np.abs(self.KY) <= ky_max)

        # Preparation of symbol or linear operator
        if self.has_psi:
            self.prepare_symbol_tables()
        else:
            # Complex dtype, as in setup_1D
            L_vals = np.array(self.L(self.KX, self.KY), dtype=np.complex128)
            self.exp_L = np.exp(L_vals * self.dt)
            if self.temporal_order == 2:
                omega_val = self.omega(self.KX, self.KY)
                self.setup_omega_terms(omega_val)

Configure internal variables for two-dimensional (2D) problems.

This private method initializes spatial and frequency grids, builds the dealiasing mask, and prepares either pseudo-differential symbols or linear operators for use in time evolution.

It assumes periodic boundary conditions and uses full complex FFT conventions (fftfreq). The spatial domain is centered at zero: [-Lx/2, Lx/2) × [-Ly/2, Ly/2).

Parameters

  • Lx : float. Physical size of the spatial domain along the x-axis.
  • Ly : float. Physical size of the spatial domain along the y-axis.
  • Nx : int. Number of grid points along the x-direction.
  • Ny : int. Number of grid points along the y-direction.

Attributes Set

  • self.Lx, self.Ly : float. Size of the spatial domain in each direction.
  • self.Nx, self.Ny : int. Number of spatial points in each direction.
  • self.x_grid, self.y_grid : np.ndarray. 1D arrays of spatial coordinates in x and y directions.
  • self.X, self.Y : np.ndarray. 2D meshgrids of spatial coordinates (indexing='ij', shape (Nx, Ny)).
  • self.kx, self.ky : np.ndarray. Arrays of wavenumbers corresponding to Fourier transforms in x and y directions.
  • self.KX, self.KY : np.ndarray. Meshgrids of wavenumbers (indexing='ij'), used in frequency space computations.
  • self.dealiasing_mask : np.ndarray. Boolean mask keeping |kx| <= dealiasing_ratio * max|kx| and |ky| <= dealiasing_ratio * max|ky|.
  • self.exp_L : np.ndarray. Exponential of the linear operator scaled by the time step, exp(L(kx, ky) · dt) (only without ψOp).
  • self.omega_val : np.ndarray. Frequency values ω(kx, ky), with L(kx, ky) = -ω(kx, ky)², used in second-order time stepping.
  • self.cos_omega_dt, self.sin_omega_dt : np.ndarray. Cosine and sine of ω(kx, ky)·dt for dispersive propagation.
  • self.inv_omega : np.ndarray. Inverse of ω(kx, ky), set to zero where ω == 0 to avoid division by zero.

Notes

  • Wavenumbers are computed with scipy.fft.fftfreq and kept in standard FFT order (zero frequency first); no fftshift is applied.
  • Dealiasing is applied using a sharp rectangular cutoff based on self.dealiasing_ratio.
  • If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via prepare_symbol_tables.
  • For second-order equations, ω(kx, ky) is evaluated from self.omega, the dispersion relation consistent with L(kx, ky) = -ω(kx, ky)²; the ω-dependent terms are set by setup_omega_terms.

See Also

  • setup_1D : Equivalent setup for one-dimensional problems.
  • prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
  • setup_omega_terms : Sets up terms involving ω(kx, ky) for second-order evolution.

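In 2D, indexing='ij' makes axis 0 correspond to x and axis 1 to y. Arrays of shape (Nx, Ny) therefore line up with `np.fft.fft2` and with KX, KY. The hypothetical `periodic_grid_2d` below mirrors these conventions and checks the Laplacian symbol -(kx² + ky²) on a non-square domain. The 2/3 ratio is again only an example value.

```python
import numpy as np

def periodic_grid_2d(Lx, Ly, Nx, Ny, dealiasing_ratio=2.0 / 3.0):
    """Meshgrids, wavenumber meshgrids and a rectangular dealiasing mask (hypothetical helper)."""
    x = np.linspace(-Lx / 2, Lx / 2, Nx, endpoint=False)
    y = np.linspace(-Ly / 2, Ly / 2, Ny, endpoint=False)
    X, Y = np.meshgrid(x, y, indexing="ij")          # shape (Nx, Ny)
    kx = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)
    ky = 2 * np.pi * np.fft.fftfreq(Ny, d=Ly / Ny)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")      # axis 0 <-> kx, axis 1 <-> ky
    mask = ((np.abs(KX) <= dealiasing_ratio * np.max(np.abs(kx)))
            & (np.abs(KY) <= dealiasing_ratio * np.max(np.abs(ky))))
    return X, Y, KX, KY, mask

X, Y, KX, KY, mask = periodic_grid_2d(2 * np.pi, 4 * np.pi, 16, 32)

# Laplacian via fft2: for sin(x) cos(y/2) the result is -(1 + 1/4) times the field
u = np.sin(X) * np.cos(Y / 2)
lap = np.fft.ifft2(-(KX**2 + KY**2) * np.fft.fft2(u)).real
print(np.allclose(lap, -1.25 * u))   # True
```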
def setup_omega_terms(self, omega_val):
    def setup_omega_terms(self, omega_val):
        """
        Initialize terms derived from the angular frequency ω for time evolution.

        This private method precomputes and stores key trigonometric and inverse quantities
        based on the dispersion relation ω(k), used in second-order time integration schemes.

        These values are essential for solving wave-like equations with dispersive behavior:
            cos(ω·dt), sin(ω·dt), 1/ω

        The inverse frequency is set to zero where ω(k) == 0 to avoid division by zero.

        Parameters
        ----------
        omega_val : np.ndarray
            Array of angular frequency values ω(k) evaluated at discrete wavenumbers.
            Can be one-dimensional (1D) or two-dimensional (2D) depending on spatial dimension.

        Attributes Set
        --------------
        - self.omega_val : np.ndarray
            The input angular frequency array (stored as given, not copied).
        - self.cos_omega_dt : np.ndarray
            Cosine of ω(k) multiplied by time step: cos(ω(k) · dt).
        - self.sin_omega_dt : np.ndarray
            Sine of ω(k) multiplied by time step: sin(ω(k) · dt).
        - self.inv_omega : np.ndarray
            Inverse of ω(k), with zeros where ω(k) == 0 to avoid division by zero.

        Notes
        -----
        - This method is typically called during setup when solving second-order PDEs
          involving dispersive waves (e.g., Klein-Gordon or water wave equations).
        - Only exact zeros of ω are guarded: the zero mode gets 1/ω = 0, while very small
          nonzero ω still produce large values of 1/ω.
        - These precomputed arrays are used in spectral propagators for accurate time stepping.

        See Also
        --------
        setup_1D : Sets up internal variables for one-dimensional problems.
        setup_2D : Sets up internal variables for two-dimensional problems.
        solve : Time integration using the computed frequency terms.
        """
        self.omega_val = omega_val
        self.cos_omega_dt = np.cos(omega_val * self.dt)
        self.sin_omega_dt = np.sin(omega_val * self.dt)
        self.inv_omega = np.zeros_like(omega_val)
        nonzero = omega_val != 0
        self.inv_omega[nonzero] = 1.0 / omega_val[nonzero]

Initialize terms derived from the angular frequency ω for time evolution.

This private method precomputes and stores key trigonometric and inverse quantities based on the dispersion relation ω(k), used in second-order time integration schemes.

These values are essential for solving wave-like equations with dispersive behavior: cos(ω·dt), sin(ω·dt) and 1/ω.

The inverse frequency is set to zero where ω(k) == 0 to avoid division by zero.

Parameters

  • omega_val : np.ndarray. Array of angular frequency values ω(k) evaluated at discrete wavenumbers; one-dimensional (1D) or two-dimensional (2D) depending on the spatial dimension.

Attributes Set

  • self.omega_val : np.ndarray. The input angular frequency array (stored as given, not copied).
  • self.cos_omega_dt : np.ndarray. Cosine of ω(k) multiplied by time step: cos(ω(k) · dt).
  • self.sin_omega_dt : np.ndarray. Sine of ω(k) multiplied by time step: sin(ω(k) · dt).
  • self.inv_omega : np.ndarray. Inverse of ω(k), with zeros where ω(k) == 0 to avoid division by zero.

Notes

  • This method is typically called during setup when solving second-order PDEs involving dispersive waves (e.g., Klein-Gordon or water wave equations).
  • Only exact zeros of ω are guarded: the zero mode gets 1/ω = 0, while very small nonzero ω still produce large values of 1/ω.
  • These precomputed arrays are used in spectral propagators for accurate time stepping.

See Also

  • setup_1D : Sets up internal variables for one-dimensional problems.
  • setup_2D : Sets up internal variables for two-dimensional problems.
  • solve : Time integration using the computed frequency terms.

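The quantities cos(ω·dt), sin(ω·dt) and 1/ω are what a per-mode propagator for u_tt = -ω(k)² u needs. The time loop in solve is not shown here, so the update below is an assumption. It is the standard exact solution of a harmonic oscillator for each Fourier mode, written with the hypothetical helpers `omega_terms` and `exact_wave_step`.

```python
import numpy as np

def omega_terms(omega_val, dt):
    """cos(w dt), sin(w dt) and 1/w with zeros where w == 0, as in setup_omega_terms."""
    omega_val = np.asarray(omega_val, dtype=float)
    cos_wdt = np.cos(omega_val * dt)
    sin_wdt = np.sin(omega_val * dt)
    inv_omega = np.zeros_like(omega_val)
    nonzero = omega_val != 0
    inv_omega[nonzero] = 1.0 / omega_val[nonzero]
    return cos_wdt, sin_wdt, inv_omega

def exact_wave_step(u_hat, v_hat, omega_val, dt):
    """One exact step of u_tt = -w^2 u for every mode (hypothetical helper)."""
    omega_val = np.asarray(omega_val, dtype=float)
    cos_wdt, sin_wdt, inv_omega = omega_terms(omega_val, dt)
    # sin(w dt)/w -> dt as w -> 0; use that limit where 1/w was zeroed
    sin_over_omega = np.where(omega_val != 0, sin_wdt * inv_omega, dt)
    u_new = cos_wdt * u_hat + sin_over_omega * v_hat
    v_new = -omega_val * sin_wdt * u_hat + cos_wdt * v_hat
    return u_new, v_new

# Three modes with w = 0, 1, 2, all starting from u = 1, u_t = 1
omega = np.array([0.0, 1.0, 2.0])
u, v = np.ones(3), np.ones(3)
dt, n_steps = 0.05, 40
for _ in range(n_steps):
    u, v = exact_wave_step(u, v, omega, dt)
```

At ω = 0 the factor sin(ω·dt)/ω should tend to dt. Multiplying sin(ω·dt) by a zeroed 1/ω gives 0 instead, which would drop the drift term v̂·dt of the zero mode, so the sketch substitutes the limit explicitly.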
def evaluate_source_at_t0(self):
    def evaluate_source_at_t0(self):
        """
        Evaluate source terms at initial time t = 0 over the spatial grid.

        This private method computes the total contribution of all source terms at the initial time,
        evaluated across the entire spatial domain. It supports both one-dimensional (1D) and
        two-dimensional (2D) configurations.

        Returns
        -------
        np.ndarray
            A numpy array representing the evaluated source term at t=0:
            - In 1D: Shape (Nx,), evaluated at each x in `self.x_grid`.
            - In 2D: Shape (Nx, Ny), evaluated at each (x, y) pair in the grid.

        Notes
        -----
        - The symbolic expressions in `self.source_terms` are substituted with numerical values at t=0.
        - In 1D, each term is evaluated at (t=0, x=x_val).
        - In 2D, each term is evaluated at (t=0, x=x_val, y=y_val).
        - Values are converted with SymPy's `evalf()` and cast to float64, so the source
          must be real-valued at t=0.
        - The source terms must be SymPy expressions (not lambdified functions), since they
          are evaluated by symbolic substitution point by point.

        See Also
        --------
        setup : Initializes the spatial grid used here.
        initialize_conditions : Uses this evaluation to compute the initial acceleration.
        solve : Uses this evaluation during the first time step.
        """
        if self.dim == 1:
            # Evaluation on the 1D spatial grid
            return np.array([
                sum(term.subs(self.t, 0).subs(self.x, x_val).evalf()
                    for term in self.source_terms)
                for x_val in self.x_grid
            ], dtype=np.float64)
        else:
            # Evaluation on the 2D spatial grid
            return np.array([
                [sum(term.subs({self.t: 0, self.x: x_val, self.y: y_val}).evalf()
                      for term in self.source_terms)
                 for y_val in self.y_grid]
                for x_val in self.x_grid
            ], dtype=np.float64)

Evaluate source terms at initial time t = 0 over the spatial grid.

This private method computes the total contribution of all source terms at the initial time, evaluated across the entire spatial domain. It supports both one-dimensional (1D) and two-dimensional (2D) configurations.

Returns

np.ndarray. The evaluated source term at t = 0:

  • In 1D: shape (Nx,), evaluated at each x in self.x_grid.
  • In 2D: shape (Nx, Ny), evaluated at each (x, y) pair in the grid.

Notes

  • The symbolic expressions in self.source_terms are substituted with numerical values at t = 0.
  • In 1D, each term is evaluated at (t=0, x=x_val).
  • In 2D, each term is evaluated at (t=0, x=x_val, y=y_val).
  • Values are converted with SymPy's evalf() and cast to float64, so the source must be real-valued at t = 0.
  • The source terms must be SymPy expressions (not lambdified functions), since they are evaluated by symbolic substitution point by point.

See Also

  • setup : Initializes the spatial grid used here.
  • initialize_conditions : Uses this evaluation to compute the initial acceleration.
  • solve : Uses this evaluation during the first time step.

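Substituting and calling evalf at every grid point costs one SymPy evaluation per point and per term. A common alternative is to substitute t = 0 once and lambdify the summed expression. The 1D sketch below compares the two approaches; it is not psipy's code, and both function names are hypothetical. The broadcast step is needed because lambdify returns a plain scalar for sources that are constant at t = 0.

```python
import numpy as np
import sympy as sp

t, x = sp.symbols("t x")

def source_at_t0_pointwise(terms, x_grid):
    """Per-point subs/evalf, following evaluate_source_at_t0 in 1D."""
    return np.array([sum(term.subs(t, 0).subs(x, xv).evalf() for term in terms)
                     for xv in x_grid], dtype=np.float64)

def source_at_t0_vectorized(terms, x_grid):
    """Same values with a single lambdify call (hypothetical alternative)."""
    total = sp.Add(*terms).subs(t, 0)
    f = sp.lambdify(x, total, "numpy")
    # broadcast_to covers sources that reduce to a constant at t = 0
    return np.broadcast_to(np.asarray(f(x_grid), dtype=np.float64), x_grid.shape).copy()

terms = [sp.sin(x) * sp.exp(-t), x**2 / 4, sp.cos(3 * t)]
x_grid = np.linspace(-np.pi, np.pi, 32, endpoint=False)
```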
def initialize_conditions(self, initial_condition, initial_velocity):
917    def initialize_conditions(self, initial_condition, initial_velocity):
918        """
919        Initialize the solution and velocity fields at t = 0.
920    
921        This private method sets up the initial state of the solution `u_prev` and, if applicable,
922        the time derivative (velocity) `v_prev` for second-order evolution equations.
923        
924        For second-order equations, it also computes the backward-in-time value `u_prev2`
925        needed by the Leap-Frog method. The acceleration at t = 0 is computed from:
926            ∂ₜ²u = L(u) + N(u) + f(x, t=0)
927        where L is the linear operator, N is the nonlinear term, and f is the source term.
928    
929        Parameters
930        ----------
931        initial_condition : callable
932            Function returning the initial condition u(x, 0) or u(x, y, 0).
933        initial_velocity : callable or None
934            Function returning the initial velocity ∂ₜu(x, 0) or ∂ₜu(x, y, 0). Required for
935            second-order equations; ignored otherwise.
936    
937        Raises
938        ------
939        ValueError
940            If `initial_velocity` is not provided for second-order equations.
941    
942        Notes
943        -----
944        - Applies periodic boundary conditions after setting initial data.
945        - Stores a copy of the initial state in `self.frames` for visualization/output.
946        - In second-order systems, initializes `self.u_prev2` using a Taylor expansion:
947          u_prev2 = u_prev - dt * v_prev + 0.5 * dt² * (∂ₜ²u)
948    
949        See Also
950        --------
951        apply_boundary : Enforces periodic boundary conditions on the solution field.
952        psiOp_apply : Computes pseudo-differential operator action for acceleration.
953        linear_rhs : Evaluates linear part of the equation in Fourier space.
954        apply_nonlinear : Handles nonlinear terms with spectral differentiation.
955        evaluate_source_at_t0 : Evaluates source terms at the initial time.
956        """
957        # Initial condition
958        if self.dim == 1:
959            self.u_prev = initial_condition(self.X)
960        else:
961            self.u_prev = initial_condition(self.X, self.Y)
962        self.apply_boundary(self.u_prev)
963    
964        # Initial velocity (second order)
965        if self.temporal_order == 2:
966            if initial_velocity is None:
967                raise ValueError("Initial velocity is required for second-order equations.")
968            if self.dim == 1:
969                self.v_prev = initial_velocity(self.X)
970            else:
971                self.v_prev = initial_velocity(self.X, self.Y)
972            self.u0 = np.copy(self.u_prev)
973            self.v0 = np.copy(self.v_prev)
974    
975            # Calculation of u_prev2 (initial acceleration)
976            if not hasattr(self, 'u_prev2'):
977                if self.has_psi:
978                    acc0 = -self.apply_psiOp(self.u_prev)
979                else:
980                    acc0 = self.linear_rhs(self.u_prev, is_v=False)
981                rhs_nl = self.apply_nonlinear(self.u_prev, is_v=False)
982                acc0 += rhs_nl / self.dt  # apply_nonlinear returns dt * N(u)
983                if hasattr(self, 'source_terms') and self.source_terms:
984                    acc0 += self.evaluate_source_at_t0()
985                self.u_prev2 = self.u_prev - self.dt * self.v_prev + 0.5 * self.dt**2 * acc0
986    
987        self.frames = [self.u_prev.copy()]

Initialize the solution and velocity fields at t = 0.

This private method sets up the initial state of the solution u_prev and, if applicable, the time derivative (velocity) v_prev for second-order evolution equations.

For second-order equations, it also computes the backward-in-time value u_prev2 needed by the leapfrog scheme. The acceleration at t = 0 is computed from ∂ₜ²u = L(u) + N(u) + f(x, t = 0), where L is the linear operator (−ψOp when pseudo-differential operators are present), N is the nonlinear term, and f is the source term.

Parameters

  • initial_condition (callable): Function returning the initial condition u(x, 0) or u(x, y, 0).
  • initial_velocity (callable or None): Function returning the initial velocity ∂ₜu(x, 0) or ∂ₜu(x, y, 0). Required for second-order equations; ignored otherwise.

Raises

ValueError: If initial_velocity is not provided for second-order equations.

Notes

  • Applies the configured boundary conditions (periodic or Dirichlet) after setting the initial data.
  • Stores a copy of the initial state in self.frames for visualization/output.
  • In second-order systems, initializes self.u_prev2 using a Taylor expansion: u_prev2 = u_prev - dt * v_prev + 0.5 * dt² * (∂ₜ²u)

See Also

  • apply_boundary: Enforces the configured boundary conditions on the solution field.
  • apply_psiOp: Computes the pseudo-differential operator action used for the initial acceleration.
  • linear_rhs: Evaluates the linear part of the equation in Fourier space.
  • apply_nonlinear: Handles nonlinear terms with spectral differentiation.
  • evaluate_source_at_t0: Evaluates source terms at the initial time.
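The Taylor start can be checked on a scalar model problem. The sketch below is illustrative and not psipy code: the oscillator u'' = −ω²u, the helper names exact and taylor_start, and the step sizes are assumptions. It compares the estimate of u(−Δt) with the exact solution; halving Δt should shrink the error by roughly 2³ = 8, because the first neglected Taylor term is of third order.

```python
import numpy as np


def exact(t, u0, v0, w):
    """Exact solution of u'' = -w**2 * u with u(0) = u0, u'(0) = v0."""
    return u0 * np.cos(w * t) + (v0 / w) * np.sin(w * t)


def taylor_start(u0, v0, acc0, dt):
    """Second-order Taylor estimate of u(-dt), as in the Notes above."""
    return u0 - dt * v0 + 0.5 * dt**2 * acc0


u0, v0, w = 1.0, 0.5, 3.0
acc0 = -w**2 * u0          # acceleration at t = 0 taken from the equation itself

errors = []
for dt in (1e-2, 5e-3):
    errors.append(abs(taylor_start(u0, v0, acc0, dt) - exact(-dt, u0, v0, w)))

ratio = errors[0] / errors[1]   # close to 8 for a third-order local error
print(errors, ratio)
```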

def apply_boundary(self, u):
 989    def apply_boundary(self, u):
 990        """
 991        Apply boundary conditions to the solution array based on the specified type.
 992    
 993        This method supports two types of boundary conditions:
 994        
 995        - 'periodic': Enforces periodicity by copying, into each edge point, the interior value next to the opposite edge.
 996        - 'dirichlet': Sets all boundary values to zero (homogeneous Dirichlet condition).
 997    
 998        Parameters
 999        ----------
1000        u : np.ndarray
1001            The solution array representing the field values on a spatial grid; it is modified in place.
1002            In 1D, shape must be (Nx,). In 2D, shape must be (Nx, Ny).
1003    
1004        Raises
1005        ------
1006        ValueError
1007            If `self.boundary_condition` is not one of {'periodic', 'dirichlet'}.
1008    
1009        Notes
1010        -----
1011        - For 'periodic':
1012            * In 1D: u[0] = u[-2], u[-1] = u[1]
1013            * In 2D: the same rule is applied along both axes (first/last rows, then first/last columns).
1014        - For 'dirichlet':
1015            * All boundary points are explicitly set to zero.
1016        """
1017    
1018        if self.boundary_condition == 'periodic':
1019            if self.dim == 1:
1020                u[0] = u[-2]
1021                u[-1] = u[1]
1022            elif self.dim == 2:
1023                u[0, :] = u[-2, :]
1024                u[-1, :] = u[1, :]
1025                u[:, 0] = u[:, -2]
1026                u[:, -1] = u[:, 1]
1027    
1028        elif self.boundary_condition == 'dirichlet':
1029            if self.dim == 1:
1030                u[0] = 0
1031                u[-1] = 0
1032            elif self.dim == 2:
1033                u[0, :] = 0
1034                u[-1, :] = 0
1035                u[:, 0] = 0
1036                u[:, -1] = 0
1037    
1038        else:
1039            raise ValueError(
1040                f"Invalid boundary condition '{self.boundary_condition}'. "
1041                "Supported types are 'periodic' and 'dirichlet'."
1042            )

Apply boundary conditions to the solution array based on the specified type.

This method supports two types of boundary conditions:

  • 'periodic': Enforces periodicity by copying, into each edge point, the interior value next to the opposite edge.
  • 'dirichlet': Sets all boundary values to zero (homogeneous Dirichlet condition).

Parameters

u (np.ndarray): The solution array representing the field values on a spatial grid; it is modified in place. In 1D, the shape must be (Nx,); in 2D, (Nx, Ny).

Raises

ValueError: If self.boundary_condition is not one of {'periodic', 'dirichlet'}.

Notes

  • For 'periodic':
    • In 1D: u[0] = u[-2], u[-1] = u[1]
    • In 2D: the same rule is applied along both axes (first/last rows, then first/last columns).
  • For 'dirichlet':
    • All boundary points are explicitly set to zero.
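A standalone restatement of the 1D branch makes the in-place effect easy to inspect. The function apply_boundary_1d below is a hypothetical sketch mirroring the rules listed above, not part of psipy.

```python
import numpy as np


def apply_boundary_1d(u, kind):
    """Hypothetical 1D restatement of apply_boundary; modifies u in place."""
    if kind == "periodic":
        # Each edge point receives the interior value next to the opposite edge.
        u[0] = u[-2]
        u[-1] = u[1]
    elif kind == "dirichlet":
        # Homogeneous Dirichlet: both edge points are set to zero.
        u[0] = 0
        u[-1] = 0
    else:
        raise ValueError(f"Invalid boundary condition '{kind}'.")
    return u


u = np.arange(6, dtype=float)                 # [0, 1, 2, 3, 4, 5]
per = apply_boundary_1d(u.copy(), "periodic")
dir_ = apply_boundary_1d(u.copy(), "dirichlet")
print(per)    # [4. 1. 2. 3. 4. 1.]
print(dir_)   # [0. 1. 2. 3. 4. 0.]
```

The periodic branch overwrites u[0] and u[-1]. If the grid stores N independent samples of one period (the usual FFT layout), those two samples are replaced rather than preserved, which is worth keeping in mind when comparing against an analytic solution.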
def apply_nonlinear(self, u, is_v=False):
1044    def apply_nonlinear(self, u, is_v=False):
1045        """
1046        Apply nonlinear terms to the solution using spectral differentiation with dealiasing.
1047
1048        This method evaluates all nonlinear terms present in the PDE by substituting spatial 
1049        derivatives with their spectral approximations computed via FFT. The dealiasing mask
1050        is applied to u before differentiation; it zeroes high-frequency modes to reduce
1051        aliasing errors in the nonlinear products.
1052
1053        Parameters
1054        ----------
1055        u : numpy.ndarray
1056            Current solution array on the spatial grid.
1057        is_v : bool
1058            If True, substitutes the stored velocity self.v_prev for u (derivatives are still taken from u).
1059
1060        Returns:
1061            numpy.ndarray: The nonlinear terms N(u) evaluated on the grid, multiplied by dt (i.e. Δt·N(u)).
1062
1063        Notes:
1064        
1065        - In 1D, computes ∂ₓu via FFT and substitutes any derivative term in the nonlinear expressions.
1066        - In 2D, computes ∂ₓu and ∂ᵧu via FFT and performs similar substitutions.
1067        - Uses lambdify to evaluate symbolic nonlinear expressions numerically, with t fixed to 0.
1068        - First-order derivatives ∂ₓu and ∂ᵧu are replaced by the symbols 'u_x' and 'u_y' before evaluation.
1069        """
1070        if not self.nonlinear_terms:
1071            return np.zeros_like(u, dtype=np.complex128)
1072        
1073        nonlinear_term = np.zeros_like(u, dtype=np.complex128)
1074    
1075        if self.dim == 1:
1076            u_hat = self.fft(u)
1077            u_hat *= self.dealiasing_mask
1078            u = self.ifft(u_hat)
1079    
1080            u_x_hat = (1j * self.KX) * u_hat
1081            u_x = self.ifft(u_x_hat)
1082    
1083            for term in self.nonlinear_terms:
1084                term_replaced = term
1085                if term.has(Derivative):
1086                    for deriv in term.atoms(Derivative):
1087                        if deriv.args[1][0] == self.x:
1088                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))            
1089                term_func = lambdify((self.t, self.x, self.u_eq, 'u_x'), term_replaced, 'numpy')
1090                if is_v:
1091                    nonlinear_term += term_func(0, self.X, self.v_prev, u_x)
1092                else:
1093                    nonlinear_term += term_func(0, self.X, u, u_x)
1094    
1095        elif self.dim == 2:
1096            u_hat = self.fft(u)
1097            u_hat *= self.dealiasing_mask
1098            u = self.ifft(u_hat)
1099    
1100            u_x_hat = (1j * self.KX) * u_hat
1101            u_y_hat = (1j * self.KY) * u_hat
1102            u_x = self.ifft(u_x_hat)
1103            u_y = self.ifft(u_y_hat)
1104    
1105            for term in self.nonlinear_terms:
1106                term_replaced = term
1107                if term.has(Derivative):
1108                    for deriv in term.atoms(Derivative):
1109                        if deriv.variables == (self.x,):  # first-order ∂ₓ only
1110                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))
1111                        elif deriv.variables == (self.y,):  # first-order ∂ᵧ only
1112                            term_replaced = term_replaced.subs(deriv, symbols('u_y'))
1113                term_func = lambdify((self.t, self.x, self.y, self.u_eq, 'u_x', 'u_y'), term_replaced, 'numpy')
1114                if is_v:
1115                    nonlinear_term += term_func(0, self.X, self.Y, self.v_prev, u_x, u_y)
1116                else:
1117                    nonlinear_term += term_func(0, self.X, self.Y, u, u_x, u_y)
1118        else:
1119            raise ValueError("Unsupported spatial dimension.")
1120        
1121        return nonlinear_term * self.dt

Apply nonlinear terms to the solution using spectral differentiation with dealiasing.

This method evaluates all nonlinear terms present in the PDE by substituting spatial derivatives with their spectral approximations computed via FFT. The dealiasing mask is applied to u before differentiation; it zeroes high-frequency modes to reduce aliasing errors in the nonlinear products.

Parameters

  • u (numpy.ndarray): Current solution array on the spatial grid.
  • is_v (bool): If True, substitutes the stored velocity self.v_prev for u in the nonlinear expressions (derivatives are still taken from u).

Returns: numpy.ndarray: The nonlinear terms N(u) evaluated on the grid, multiplied by dt (i.e. Δt·N(u)).

Notes:

  • In 1D, computes ∂ₓu via FFT and substitutes any derivative term in the nonlinear expressions.
  • In 2D, computes ∂ₓu and ∂ᵧu via FFT and performs similar substitutions.
  • Uses lambdify to evaluate symbolic nonlinear expressions numerically, with t fixed to 0.
  • First-order derivatives ∂ₓu and ∂ᵧu are replaced by the symbols u_x and u_y before evaluation.
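The substitution-and-lambdify pattern described above can be reproduced in isolation. This is a minimal sketch under assumed settings (a 64-point periodic grid on [0, 2π), a 2/3-rule mask, and the Burgers-type term −u ∂ₓu); it returns N(u) itself, without the dt factor that apply_nonlinear applies. Restricting the substitution to derivatives whose variables are exactly (x,) is a choice made here so that higher-order derivatives are not mistaken for u_x.

```python
import numpy as np
import sympy as sp

t, x = sp.symbols("t x", real=True)
u = sp.Function("u")(t, x)
term = -u * sp.Derivative(u, x)                 # Burgers-type nonlinearity

Nx = 64
X = 2 * np.pi * np.arange(Nx) / Nx
KX = np.fft.fftfreq(Nx, d=X[1] - X[0]) * 2 * np.pi   # integer angular wavenumbers
mask = np.abs(KX) <= (2 / 3) * np.abs(KX).max()       # 2/3-rule dealiasing mask

# Replace first-order x-derivatives by a plain symbol, then lambdify.
u_x_sym = sp.Symbol("u_x")
replaced = term.subs({d: u_x_sym for d in term.atoms(sp.Derivative)
                      if d.variables == (x,)})
f = sp.lambdify((t, x, u, u_x_sym), replaced, "numpy")

u_hat = np.fft.fft(np.sin(X)) * mask
u_dealiased = np.fft.ifft(u_hat).real
u_x = np.fft.ifft(1j * KX * u_hat).real
N_vals = f(0.0, X, u_dealiased, u_x)            # N(u) = -u * u_x on the grid

exact = -np.sin(X) * np.cos(X)
print(np.max(np.abs(N_vals - exact)))
```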
def prepare_symbol_tables(self):
1123    def prepare_symbol_tables(self):
1124        """
1125        Precompute and store evaluated pseudo-differential operator symbols for spectral methods.
1126
1127        This method evaluates all pseudo-differential operators (ψOp) present in the PDE
1128        over the spatial and frequency grids, scales them by their respective coefficients,
1129        and combines them into a single composite symbol used in time-stepping and inversion.
1130
1131        The evaluation is performed via the `evaluate` method of each PseudoDifferentialOperator,
1132        which computes p(x, ξ) or p(x, y, ξ, η) numerically over the current grid configuration.
1133
1134        Side Effects:
1135            self.precomputed_symbols : list of (coeff, symbol_array)
1136                Each tuple contains a coefficient and its evaluated symbol on the grid.
1137            self.combined_symbol : np.ndarray
1138                Sum of all scaled symbol arrays: ∑(coeffₖ * ψₖ(x, ξ))
1139
1140        Raises:
1141            ValueError: If the spatial dimension is not 1D or 2D.
1142        """
1143        self.precomputed_symbols = []
1144        self.combined_symbol = 0
1145        for coeff, psi in self.psi_ops:
1146            if self.dim == 1:
1147                raw = psi.evaluate(self.X, None, self.KX, None)
1148            elif self.dim == 2:
1149                raw = psi.evaluate(self.X, self.Y, self.KX, self.KY)
1150            else:
1151                raise ValueError('Unsupported spatial dimension.')
1152            raw_flat = raw.flatten()
1153            converted = np.array([complex(N(val)) for val in raw_flat], dtype=np.complex128)
1154            raw_eval = converted.reshape(raw.shape)
1155            self.precomputed_symbols.append((coeff, raw_eval))
1156        self.combined_symbol = sum((coeff * sym for coeff, sym in self.precomputed_symbols))
1157        self.combined_symbol = np.array(self.combined_symbol, dtype=np.complex128)

Precompute and store evaluated pseudo-differential operator symbols for spectral methods.

This method evaluates all pseudo-differential operators (ψOp) present in the PDE over the spatial and frequency grids, scales them by their respective coefficients, and combines them into a single composite symbol used in time-stepping and inversion.

The evaluation is performed via the evaluate method of each PseudoDifferentialOperator, which computes p(x, ξ) or p(x, y, ξ, η) numerically over the current grid configuration.

Side Effects:

  • self.precomputed_symbols (list of (coeff, symbol_array)): each tuple contains a coefficient and its evaluated symbol on the grid.
  • self.combined_symbol (np.ndarray): sum of all scaled symbol arrays, ∑ₖ coeffₖ ⋅ ψₖ(x, ξ).

Raises: ValueError: If the spatial dimension is not 1D or 2D.
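A vectorized variant of the same precomputation can be sketched with lambdify instead of converting each grid value through N(). The (x, ξ) meshgrid layout, the example symbols and the helper evaluate_symbol are assumptions for illustration; psipy's own evaluate method may use a different grid convention.

```python
import numpy as np
import sympy as sp

x, xi = sp.symbols("x xi", real=True)

# Hypothetical (coefficient, symbol) pairs, in the spirit of self.psi_ops.
psi_terms = [(1.0, xi**2), (0.5, sp.cos(x) * sp.Abs(xi)), (2.0, sp.Integer(1))]

Nx = 32
xg = 2 * np.pi * np.arange(Nx) / Nx
kg = np.fft.fftfreq(Nx, d=xg[1] - xg[0]) * 2 * np.pi
X, KX = np.meshgrid(xg, kg, indexing="ij")      # full (x, xi) grid


def evaluate_symbol(expr):
    """Evaluate p(x, xi) on the grid; broadcast_to covers constant symbols."""
    f = sp.lambdify((x, xi), expr, "numpy")
    return np.broadcast_to(np.asarray(f(X, KX), dtype=np.complex128), X.shape)


precomputed = [(c, evaluate_symbol(p)) for c, p in psi_terms]
combined = sum(c * s for c, s in precomputed)   # sum_k coeff_k * p_k(x, xi)
```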

def total_symbol_expr(self):
1159    def total_symbol_expr(self):
1160        """
1161        Compute the total pseudo-differential symbol expression from all pseudo_terms.
1162
1163        This method constructs the full symbol of the pseudo-differential operator
1164        by summing up all coefficient-weighted symbolic expressions.
1165
1166        The result is cached in self.symbol_expr to avoid recomputation.
1167
1168        Returns:
1169            sympy.Expr: The combined symbol expression, representing the full
1170                        pseudo-differential operator in symbolic form.
1171
1172        Example:
1173            Given pseudo_terms = [(2, ξ²), (1, x·ξ)], this returns 2·ξ² + x·ξ.
1174        """
1175        if not hasattr(self, 'symbol_expr'):
1176            self.symbol_expr = sum(coeff * expr for coeff, expr in self.pseudo_terms)
1177        return self.symbol_expr

Compute the total pseudo-differential symbol expression from all pseudo_terms.

This method constructs the full symbol of the pseudo-differential operator by summing up all coefficient-weighted symbolic expressions.

The result is cached in self.symbol_expr to avoid recomputation.

Returns: sympy.Expr: The combined symbol expression, representing the full pseudo-differential operator in symbolic form.

Example: Given pseudo_terms = [(2, ξ²), (1, x·ξ)], this returns 2·ξ² + x·ξ.
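The caching behaviour can also be expressed with functools.cached_property, which computes the value on first access and stores it on the instance. The class SymbolSum and its n_builds counter are hypothetical and only illustrate the pattern with the example from the docstring. As with any cache, the stored sum does not reflect later changes to pseudo_terms.

```python
from functools import cached_property

import sympy as sp

x, xi = sp.symbols("x xi", real=True)


class SymbolSum:
    """Hypothetical holder of pseudo_terms with a cached total symbol."""

    def __init__(self, pseudo_terms):
        self.pseudo_terms = pseudo_terms
        self.n_builds = 0                 # how many times the sum was computed

    @cached_property
    def symbol_expr(self):
        self.n_builds += 1
        # sp.Add(*[]) gives 0, so an empty term list is handled too
        return sp.Add(*[coeff * expr for coeff, expr in self.pseudo_terms])


op = SymbolSum([(2, xi**2), (1, x * xi)])
first = op.symbol_expr                    # computed here
second = op.symbol_expr                   # served from the cache
print(first, op.n_builds)
```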

def build_symbol_func(self, expr):
1179    def build_symbol_func(self, expr):
1180        """
1181        Build a numerical evaluation function from a symbolic pseudo-differential operator expression.
1182    
1183        This method converts a symbolic expression representing a pseudo-differential operator into
1184        a callable NumPy-compatible function. The function accepts spatial and frequency variables
1185        depending on the dimensionality of the problem.
1186    
1187        Parameters
1188        ----------
1189        expr : sympy expression
1190            A SymPy expression representing the symbol of the pseudo-differential operator. It may depend on spatial variables (x, y) and frequency variables (xi, eta).
1191    
1192        Returns:
1193            function : A lambdified function that takes:
1194            
1195                - In 1D: `(x, xi)` — spatial coordinate and frequency.
1196                - In 2D: `(x, y, xi, eta)` — spatial coordinates and frequencies.
1197                
1198              Returns a NumPy array of evaluated symbol values over input grids.
1199    
1200        Notes:
1201            - Uses `lambdify` from SymPy with the `'numpy'` backend for efficient vectorized evaluation.
1202            - Argument symbols are declared real; lambdify binds them by name, so `expr` must use x, xi (1D) or x, y, xi, eta (2D).
1203            - Used internally by methods like `apply_psiOp`, `evaluate`, and visualization tools.
1204        """
1205        if self.dim == 1:
1206            x, xi = symbols('x xi', real=True)
1207            return lambdify((x, xi), expr, 'numpy')
1208        else:
1209            x, y, xi, eta = symbols('x y xi eta', real=True)
1210            return lambdify((x, y, xi, eta), expr, 'numpy')

Build a numerical evaluation function from a symbolic pseudo-differential operator expression.

This method converts a symbolic expression representing a pseudo-differential operator into a callable NumPy-compatible function. The function accepts spatial and frequency variables depending on the dimensionality of the problem.

Parameters

expr (sympy expression): A SymPy expression representing the symbol of the pseudo-differential operator. It may depend on the spatial variables (x, y) and the frequency variables (xi, eta).

Returns: function : A lambdified function that takes:

    - In 1D: `(x, xi)` — spatial coordinate and frequency.
    - In 2D: `(x, y, xi, eta)` — spatial coordinates and frequencies.

  Returns a NumPy array of evaluated symbol values over input grids.

Notes:

  • Uses lambdify from SymPy with the 'numpy' backend for efficient vectorized evaluation.
  • The argument symbols are declared real; lambdify binds them by name, so expr must use x, xi (1D) or x, y, xi, eta (2D).
  • Used internally by methods like apply_psiOp, evaluate, and visualization tools.
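Because lambdify binds arguments by name, a mistyped variable (for example k instead of xi) only fails when the generated function is called, and a constant symbol returns a scalar instead of a grid. The sketch below is a hypothetical variant of build_symbol_func, not the psipy implementation, that checks the free symbols up front and broadcasts the result to the grid shape.

```python
import numpy as np
import sympy as sp


def build_symbol_func(expr, dim):
    """Hypothetical variant of build_symbol_func that validates names and shapes."""
    names = "x xi" if dim == 1 else "x y xi eta"
    args = sp.symbols(names, real=True)
    # lambdify binds arguments by name, so reject names it would not recognise.
    unknown = {s.name for s in expr.free_symbols} - {a.name for a in args}
    if unknown:
        raise ValueError(f"symbol uses unsupported variables: {sorted(unknown)}")
    f = sp.lambdify(args, expr, "numpy")

    def wrapped(*grids):
        # Constant symbols (e.g. p = 3) make lambdify return a scalar; broadcast it.
        shape = np.broadcast_shapes(*(np.shape(g) for g in grids))
        return np.broadcast_to(f(*grids), shape)

    return wrapped


xi_plain = sp.Symbol("xi")            # no real=True: still matched by name
p = build_symbol_func(sp.sqrt(1 + xi_plain**2), dim=1)
xg = np.linspace(0.0, 1.0, 4)
kg = np.array([0.0, 1.0, 2.0, 3.0])
vals = p(xg[:, None], kg[None, :])    # shape (4, 4): rows x, columns xi
const = build_symbol_func(sp.Integer(3), dim=1)(xg[:, None], kg[None, :])
```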

def apply_psiOp(self, u):
1212    def apply_psiOp(self, u):
1213        """
1214        Apply the pseudo-differential operator to the input field u.
1215    
1216        This method dispatches the application of the pseudo-differential operator based on:
1217        
1218        - Whether the symbol is spatially dependent (x/y)
1219        - The boundary condition in use (periodic or dirichlet)
1220    
1221        Supported operations:
1222        
1223        - Constant-coefficient symbols: applied via Fourier multiplication.
1224        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
1225        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.
1226    
1227        Dispatch Logic:\n
1228        if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]\n
1229        elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated with FFTs (faster)\n
1230        elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, non-periodic quantization (slower)\n
1231        
1232        This method delegates to the apply() method of each 
1233        PseudoDifferentialOperator instance.
1234        
1235        Parameters
1236        ----------
1237        u : ndarray
1238            Function to which operators are applied
1239            
1240        Returns
1241        -------
1242        ndarray
1243            Result of applying all operators with their coefficients
1244        """
1245        if not hasattr(self, 'psi_ops') or not self.psi_ops:
1246            raise ValueError("No pseudo-differential operators defined")
1247        
1248        result = np.zeros_like(u, dtype=np.complex128)
1249        
1250        for coeff, psi_op in self.psi_ops:
1251            coeff = np.complex128(coeff)
1252            if self.dim == 1:
1253                contribution = psi_op.apply(
1254                    u=u,
1255                    x_grid=self.x_grid,
1256                    kx=self.kx,
1257                    boundary_condition=self.boundary_condition,
1258                    dealiasing_mask=self.dealiasing_mask
1259                )
1260            elif self.dim == 2:
1261                contribution = psi_op.apply(
1262                    u=u,
1263                    x_grid=self.x_grid,
1264                    kx=self.kx,
1265                    y_grid=self.y_grid,
1266                    ky=self.ky,
1267                    boundary_condition=self.boundary_condition,
1268                    dealiasing_mask=self.dealiasing_mask
1269                )
1270            else:
1271                raise ValueError("Only 1D and 2D supported")
1272            
1273            result += coeff * contribution
1274        
1275        return result

Apply the pseudo-differential operator to the input field u.

This method dispatches the application of the pseudo-differential operator based on:

  • Whether the symbol is spatially dependent (x/y)
  • The boundary condition in use (periodic or dirichlet)

Supported operations:

  • Constant-coefficient symbols: applied via Fourier multiplication.
  • Spatially varying symbols: applied via Kohn–Nirenberg quantization.
  • Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.

Dispatch Logic:

if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]

elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, evaluated with FFTs (faster)

elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ e^{ixξ} p(x, ξ) 𝓕(u)(ξ) dξ, non-periodic quantization (slower)

This method delegates to the apply() method of each PseudoDifferentialOperator instance.

Parameters

u (ndarray): Field to which the operators are applied.

Returns

ndarray: The coefficient-weighted sum Σₖ coeffₖ ⋅ Op(pₖ)u. A ValueError is raised if no pseudo-differential operators are defined.
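For a periodic 1D grid, both quantizations listed in the dispatch logic can be written in a few lines. This sketch assumes a 32-point grid on [0, 2π) with integer wavenumbers; the helpers apply_multiplier and apply_kohn_nirenberg are hypothetical, and the Kohn-Nirenberg version is a dense O(N²) sum meant for checking results, not for speed. The sign convention matches the solver's use of Lu = −apply_psiOp(u): the symbol ξ² acts as −∂ₓ².

```python
import numpy as np

Nx = 32
x = 2 * np.pi * np.arange(Nx) / Nx
k = np.fft.fftfreq(Nx, d=x[1] - x[0]) * 2 * np.pi   # integer wavenumbers


def apply_multiplier(p_of_k, u):
    """Op(p)(D)u = F^{-1}[p(xi) F(u)] for a symbol that does not depend on x."""
    return np.fft.ifft(p_of_k(k) * np.fft.fft(u))


def apply_kohn_nirenberg(p_of_xk, u):
    """Op(p)(x, D)u at x_j = (1/N) sum_m exp(i x_j k_m) p(x_j, k_m) u_hat_m (dense, O(N^2))."""
    u_hat = np.fft.fft(u)
    phase = np.exp(1j * np.outer(x, k))            # exp(i x_j k_m)
    return (phase * p_of_xk(x[:, None], k[None, :])) @ u_hat / Nx


u = np.sin(3 * x)
lap = apply_multiplier(lambda kk: kk**2, u)                       # -u'' = 9 sin(3x)
kn = apply_kohn_nirenberg(lambda xx, kk: (2 + np.cos(xx)) * kk**2, u)
```

For a symbol of the form a(x)ξ², the Kohn-Nirenberg result is simply a(x) times the Fourier-multiplier result, which gives a convenient consistency check.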

def step_order1_with_psi(self, source_contribution):
1277    def step_order1_with_psi(self, source_contribution):
1278        """
1279        Perform one time step of a first-order evolution using a pseudo-differential operator.
1280    
1281        This method updates the solution field using an exponential integrator or explicit Euler scheme,
1282        depending on boundary conditions and the structure of the pseudo-differential symbol.
1283        It supports:
1284        - Linear dynamics via pseudo-differential operator L (possibly nonlocal)
1285        - Nonlinear terms computed via spectral differentiation
1286        - External source contributions
1287    
1288        The update follows **three distinct computational paths**:
1289    
1290        1. **Periodic boundaries + spatially constant symbol p(ξ)**  
1291           Exact linear propagation in Fourier space, then an explicit increment:  
1292               uₙ₊₁ = 𝓕⁻¹[exp(−p(ξ)Δt) ⋅ 𝓕(uₙ)] + Δt ⋅ (N(uₙ) + F)
1293    
1294        2. **Non-periodic boundaries + spatially constant symbol p(ξ)**  
1295           Exponential time differencing of order 1 (ETD1), in Fourier space:  
1296               ûₙ₊₁ = exp(−p(ξ)Δt) ⋅ ûₙ + Δt ⋅ φ₁(−p(ξ)Δt) ⋅ (N̂(uₙ) + F̂)
1297    
1298        3. **Spatially varying symbol p(x, ξ)**  
1299           No Fourier diagonalization available → explicit Euler, then a spectral low-pass filter:  
1300               uₙ₊₁ = uₙ + Δt ⋅ (L(uₙ) + N(uₙ) + F)
1301    
1302        where:
1303            L(uₙ) = −Op(p)uₙ, linear part from the combined ψOp symbol p
1304            N(uₙ) = nonlinear contribution at current time step
1305            F     = external source term
1306            Δt    = time step size
1307            φ₁(z) = (eᶻ − 1)/z (φ₁(0) = 1 is substituted where the quotient is 0/0)
1308    
1309        Boundary conditions are applied after each update to ensure consistency.
1310    
1311        Parameters
1312            source_contribution (np.ndarray): Array representing the external source term at current time step.
1313                                              Must match the spatial dimensions of self.u_prev.
1314    
1315        Returns:
1316            np.ndarray: Updated solution array after one time step.
1317        """
1318        # Handling null source
1319        if np.isscalar(source_contribution):
1320            source = np.zeros_like(self.u_prev)
1321        else:
1322            source = source_contribution
1323
1324        def spectral_filter(u, cutoff=0.8):
1325            if u.ndim == 1:
1326                u_hat = self.fft(u)
1327                N = len(u)
1328                k = fftfreq(N)
1329                mask = np.exp(-(k / cutoff)**8)
1330                return self.ifft(u_hat * mask).real
1331            elif u.ndim == 2:
1332                u_hat = self.fft(u)
1333                Ny, Nx = u.shape
1334                ky = fftfreq(Ny)[:, None]
1335                kx = fftfreq(Nx)[None, :]
1336                k_squared = kx**2 + ky**2
1337                mask = np.exp(-(np.sqrt(k_squared) / cutoff)**8)
1338                return self.ifft(u_hat * mask).real
1339            else:
1340                raise ValueError("Only 1D and 2D arrays are supported.")
1341
1342        # Recalculate symbol if necessary
1343        if self.is_spatial:
1344            self.prepare_symbol_tables()  # Recalculates self.combined_symbol
1345    
1346        # Case with FFT (symbol diagonalizable in Fourier space)
1347        if self.boundary_condition == 'periodic' and not self.is_spatial:
1348            u_hat = self.fft(self.u_prev)
1349            u_hat *= np.exp(-self.dt * self.combined_symbol)
1350            u_hat *= self.dealiasing_mask
1351            u_symb = self.ifft(u_hat)
1352            u_nl = self.apply_nonlinear(self.u_prev)
1353            u_new = u_symb + u_nl + self.dt * source  # u_nl already includes dt
1354        else:
1355            if not self.is_spatial:
1356                # General case with ETD1
1357                u_nl = self.apply_nonlinear(self.u_prev)
1358    
1359                # Calculation of exp(-dt * L) and phi1(-dt * L)
1360                L_vals = self.combined_symbol  # Uses the updated symbol
1361                exp_L = np.exp(-self.dt * L_vals)
1362                phi1_L = (1.0 - exp_L) / (self.dt * L_vals)  # = (e^z - 1)/z with z = -dt * L
1363                phi1_L[np.isnan(phi1_L)] = 1.0  # Limit phi1(0) = 1 where L = 0
1364    
1365                # Fourier transform
1366                u_hat = self.fft(self.u_prev)
1367                u_nl_hat = self.fft(u_nl)
1368                source_hat = self.fft(source)
1369    
1370                # Assembling the solution in Fourier space
1371                u_hat_new = exp_L * u_hat + phi1_L * (u_nl_hat + self.dt * source_hat)  # u_nl includes dt
1372                u_new = self.ifft(u_hat_new)
1373            else:
1374                # if the symbol depends on spatial variables : Euler method
1375                Lu_prev = -self.apply_psiOp(self.u_prev)
1376                u_nl = self.apply_nonlinear(self.u_prev)
1377                u_new = self.u_prev + self.dt * (Lu_prev + source) + u_nl  # u_nl includes dt
1378                u_new = spectral_filter(u_new, cutoff=self.dealiasing_ratio)
1379        # Applying boundary conditions
1380        self.apply_boundary(u_new)
1381        return u_new

Perform one time step of a first-order evolution using a pseudo-differential operator.

This method updates the solution field using an exponential integrator or explicit Euler scheme, depending on boundary conditions and the structure of the pseudo-differential symbol. It supports:

  • Linear dynamics via pseudo-differential operator L (possibly nonlocal)
  • Nonlinear terms computed via spectral differentiation
  • External source contributions

The update follows three distinct computational paths:

  1. Periodic boundaries + spatially constant symbol p(ξ)
    Exact linear propagation in Fourier space, then an explicit increment:
    uₙ₊₁ = 𝓕⁻¹[exp(−p(ξ)Δt) ⋅ 𝓕(uₙ)] + Δt ⋅ (N(uₙ) + F)

  2. Non-periodic boundaries + spatially constant symbol p(ξ)
    Exponential time differencing of order 1 (ETD1), in Fourier space:
    ûₙ₊₁ = exp(−p(ξ)Δt) ⋅ ûₙ + Δt ⋅ φ₁(−p(ξ)Δt) ⋅ (N̂(uₙ) + F̂)

  3. Spatially varying symbol p(x, ξ)
    No Fourier diagonalization available → explicit Euler, then a spectral low-pass filter:
    uₙ₊₁ = uₙ + Δt ⋅ (L(uₙ) + N(uₙ) + F)

where:

  • L(uₙ) = −Op(p)uₙ, the linear part from the combined ψOp symbol p
  • N(uₙ) = nonlinear contribution at the current time step
  • F = external source term
  • Δt = time step size
  • φ₁(z) = (eᶻ − 1)/z, with φ₁(0) = 1 substituted where the quotient is 0/0

Boundary conditions are applied after each update to ensure consistency.

Parameters: source_contribution (np.ndarray): Array representing the external source term at the current time step. Must match the spatial dimensions of self.u_prev; a scalar (e.g. 0 when no source is defined) is replaced by a zero array.

Returns: np.ndarray: Updated solution array after one time step.
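The exponential path can be checked on a case where ETD1 is exact: a linear equation with constant forcing. This sketch assumes the heat equation u_t = u_xx + 1 on [0, 2π) (symbol p(ξ) = ξ²) and uses np.expm1 for a cancellation-free φ₁; the helpers phi1 and etd1_step are hypothetical. Here the nonlinear callback returns N(u) itself, whereas apply_nonlinear returns Δt·N(u).

```python
import numpy as np


def phi1(z):
    """phi_1(z) = (exp(z) - 1) / z, evaluated stably (phi_1(0) = 1)."""
    z = np.asarray(z, dtype=np.complex128)
    out = np.empty_like(z)
    small = np.abs(z) < 1e-8
    out[~small] = np.expm1(z[~small]) / z[~small]
    out[small] = 1.0 + z[small] / 2.0       # series expansion near z = 0
    return out


def etd1_step(u, p_hat, nonlinear, source, dt):
    """One ETD1 step for u_t = -Op(p)u + N(u) + F with a symbol p(xi) independent of x."""
    z = -dt * p_hat
    u_hat = np.fft.fft(u)
    forcing_hat = np.fft.fft(nonlinear(u) + source)
    return np.fft.ifft(np.exp(z) * u_hat + dt * phi1(z) * forcing_hat)


# Model problem (assumed for illustration): heat equation u_t = u_xx + 1 on [0, 2*pi).
Nx, dt, n_steps = 32, 0.1, 10
x = 2 * np.pi * np.arange(Nx) / Nx
k = np.fft.fftfreq(Nx, d=x[1] - x[0]) * 2 * np.pi
u = np.sin(x).astype(np.complex128)
for _ in range(n_steps):
    u = etd1_step(u, k**2, lambda v: np.zeros_like(v), 1.0, dt)

t = n_steps * dt
exact = np.exp(-t) * np.sin(x) + t      # decaying mode plus linear growth of the mean
err = np.max(np.abs(u.real - exact))
```

Because φ₁ multiplies the forcing, the zero mode (p = 0) grows by exactly Δt per step, while decaying modes relax to the correct steady state; a plain Euler update of the forcing would only reproduce this to first order in Δt.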

def step_order2_with_psi(self, source_contribution):
1383    def step_order2_with_psi(self, source_contribution):
1384        """
1385        Perform one time step of a second-order time evolution using a pseudo-differential operator.
1386    
1387        This method updates the solution field using a second-order accurate scheme suitable for wave-like equations.
1388        The update includes contributions from:
1389        - Linear dynamics via a pseudo-differential operator (e.g., dispersion or stiffness)
1390        - Nonlinear terms computed via spectral differentiation
1391        - External source contributions
1392    
1393        Discretization follows a leapfrog-style finite difference in time:
1394        
1395            uₙ₊₁ = 2uₙ − uₙ₋₁ + Δt² ⋅ (L(uₙ) + N(uₙ) + F)
1396    
1397        where:
1398            L(uₙ) = −Op(p)uₙ, linear part from the combined ψOp symbol p
1399            N(uₙ) = nonlinear contribution at current time step
1400            F     = external source term at current time step
1401            Δt    = time step size
1402    
1403        Boundary conditions are applied after each update to ensure consistency.
1404    
1405        Parameters
1406            source_contribution (np.ndarray): Array representing the external source term at current time step.
1407                                              Must match the spatial dimensions of self.u_prev.
1408    
1409        Returns:
1410            np.ndarray: Updated solution array after one time step.
1411        """
1412        Lu_prev = -self.apply_psiOp(self.u_prev)
1413        rhs_nl = self.apply_nonlinear(self.u_prev, is_v=False)
1414        u_new = 2 * self.u_prev - self.u_prev2 + self.dt ** 2 * (Lu_prev + source_contribution) + self.dt * rhs_nl  # rhs_nl = dt * N
1415        self.apply_boundary(u_new)
1416        self.u_prev2 = self.u_prev
1417        self.u_prev = u_new
1418        self.u = u_new
1419        return u_new

Perform one time step of a second-order time evolution using a pseudo-differential operator.

This method updates the solution field using a second-order accurate scheme suitable for wave-like equations. The update includes contributions from:

  • Linear dynamics via a pseudo-differential operator (e.g., dispersion or stiffness)
  • Nonlinear terms computed via spectral differentiation
  • External source contributions

Discretization follows a leapfrog-style finite difference in time:

uₙ₊₁ = 2uₙ − uₙ₋₁ + Δt² ⋅ (L(uₙ) + N(uₙ) + F)

where:

  • L(uₙ) = −Op(p)uₙ, the linear part from the combined ψOp symbol p
  • N(uₙ) = nonlinear contribution at the current time step
  • F = external source term at the current time step
  • Δt = time step size

Boundary conditions are applied after each update to ensure consistency.

Parameters: source_contribution (np.ndarray): Array representing the external source term at the current time step. Must match the spatial dimensions of self.u_prev.

Returns: np.ndarray: Updated solution array after one time step.
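The leapfrog update can be checked on the linear wave equation, whose exact solution is known. This is a sketch with assumed parameters (64 periodic points, symbol p(ξ) = ξ², u(x, 0) = sin 2x and zero initial velocity), not psipy code. Leapfrog applied to a mode of frequency ω is stable only for ωΔt < 2, so the time step is tied to the largest frequency on the grid; here it is set to half that limit.

```python
import numpy as np

# Model problem (assumed): u_tt = u_xx on [0, 2*pi), i.e. u_tt = -Op(p)u with p(xi) = xi**2.
Nx = 64
x = 2 * np.pi * np.arange(Nx) / Nx
k = np.fft.fftfreq(Nx, d=x[1] - x[0]) * 2 * np.pi
p_hat = k**2


def L(u):
    """Linear part L(u) = -Op(p)u, applied as a Fourier multiplier."""
    return -np.fft.ifft(p_hat * np.fft.fft(u)).real


# Stability needs w_max * dt < 2; use half of that bound.
w_max = np.sqrt(p_hat.max())
dt = 1.0 / w_max
courant = w_max * dt

u_prev = np.sin(2 * x)                          # u(x, 0); initial velocity is zero
u_prev2 = u_prev + 0.5 * dt**2 * L(u_prev)      # Taylor start for u(x, -dt)
n_steps = int(round(1.0 / dt))
for _ in range(n_steps):
    u_new = 2 * u_prev - u_prev2 + dt**2 * L(u_prev)
    u_prev2, u_prev = u_prev, u_new

t_final = n_steps * dt
err = np.max(np.abs(u_prev - np.cos(2 * t_final) * np.sin(2 * x)))
```

The remaining error is dominated by the scheme's phase lag, which scales like (ωΔt)²; choosing a time step well below the stability bound keeps it small.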

def solve(self):
    def solve(self):
        """
        Solve the partial differential equation numerically using spectral methods.

        This method evolves the solution in time using a combination of:
        - Fourier-based linear evolution (with dealiasing)
        - Nonlinear term handling via pseudo-spectral evaluation
        - Support for pseudo-differential operators (psiOp)
        - Source terms and boundary conditions

        The solver supports:
        - 1D and 2D spatial domains
        - First and second-order time evolution
        - Periodic and Dirichlet boundary conditions
        - Time-stepping schemes: default, ETD-RK4

        Returns:
            list[np.ndarray]: A list of solution arrays at each saved time frame.

        Side Effects:
            - Updates self.frames: stores solution snapshots
            - Updates self.energy_history: records the total energy at every time step
              (second-order problems without psiOp only)

        Algorithm Overview:
            For each time step:
                1. Evaluate source contributions (if any)
                2. Apply time evolution:
                    - Order 1:
                        - With psiOp: uses step_order1_with_psi
                        - With ETD-RK4: exponential time differencing (step_ETD_RK4)
                        - Default: linear + nonlinear update
                    - Order 2:
                        - With psiOp: uses step_order2_with_psi
                        - With ETD-RK4: uses step_ETD_RK4_order2
                        - Default: exact cos/sin propagation of the linear part in Fourier
                          space, plus an explicit nonlinear correction
                3. Enforce boundary conditions
                4. Save solution snapshot periodically
                5. Record energy (for second-order systems without psiOp)
        """
        print('\n*******************')
        print('* Solving the PDE *')
        print('*******************\n')
        save_interval = max(1, self.Nt // self.n_frames)
        self.energy_history = []
        for step in range(self.Nt):
            if hasattr(self, 'source_terms') and self.source_terms:
                source_contribution = np.zeros_like(self.X, dtype=np.float64)
                for term in self.source_terms:
                    try:
                        if self.dim == 1:
                            source_func = lambdify((self.t, self.x), term, 'numpy')
                            source_contribution += source_func(step * self.dt, self.X)
                        elif self.dim == 2:
                            source_func = lambdify((self.t, self.x, self.y), term, 'numpy')
                            source_contribution += source_func(step * self.dt, self.X, self.Y)
                    except Exception as e:
                        print(f'Error evaluating source term {term}: {e}')
            else:
                source_contribution = 0

            if self.temporal_order == 1:
                if self.has_psi:
                    u_new = self.step_order1_with_psi(source_contribution)
                elif hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
                    u_new = self.step_ETD_RK4(self.u_prev)
                else:
                    u_hat = self.fft(self.u_prev)
                    u_hat *= self.exp_L
                    u_hat *= self.dealiasing_mask
                    u_lin = self.ifft(u_hat)
                    u_nl = self.apply_nonlinear(u_lin)
                    u_new = u_lin + u_nl + source_contribution
                self.apply_boundary(u_new)
                self.u_prev = u_new

            elif self.temporal_order == 2:
                if self.has_psi:
                    u_new = self.step_order2_with_psi(source_contribution)
                else:
                    if hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
                        u_new, v_new = self.step_ETD_RK4_order2(self.u_prev, self.v_prev)
                    else:
                        u_hat = self.fft(self.u_prev)
                        v_hat = self.fft(self.v_prev)
                        u_new_hat = self.cos_omega_dt * u_hat + self.sin_omega_dt * self.inv_omega * v_hat
                        v_new_hat = -self.omega_val * self.sin_omega_dt * u_hat + self.cos_omega_dt * v_hat
                        u_new = self.ifft(u_new_hat)
                        v_new = self.ifft(v_new_hat)
                        u_nl = self.apply_nonlinear(self.u_prev, is_v=False)
                        v_nl = self.apply_nonlinear(self.v_prev, is_v=True)
                        u_new += (u_nl + source_contribution) * self.dt ** 2 / 2
                        v_new += (u_nl + source_contribution) * self.dt
                    self.apply_boundary(u_new)
                    self.apply_boundary(v_new)
                    self.u_prev = u_new
                    self.v_prev = v_new

            if step % save_interval == 0:
                self.frames.append(self.u_prev.copy())

            if self.temporal_order == 2 and (not self.has_psi):
                E = self.compute_energy()
                self.energy_history.append(E)

        return self.frames

Solve the partial differential equation numerically using spectral methods.

This method evolves the solution in time using a combination of:

  • Fourier-based linear evolution (with dealiasing)
  • Nonlinear term handling via pseudo-spectral evaluation
  • Support for pseudo-differential operators (psiOp)
  • Source terms and boundary conditions
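The first of these ingredients, a Fourier-based linear step with dealiasing, can be illustrated in isolation. Below is a minimal 1D sketch, assuming a periodic grid, a Fourier multiplier L(k) (here the heat operator, L(k) = -k²) and the common two-thirds truncation rule for the mask. The function name and the choice of mask are illustrative assumptions, not the solver's own exp_L or dealiasing_mask.

```python
import numpy as np

def dealiased_linear_step(u, L_of_k, dt, Lx):
    """One step u_hat <- exp(dt*L(k)) * mask * u_hat on a periodic 1D grid (sketch).

    The 2/3-rule mask is an assumption; the solver's dealiasing_mask may differ.
    """
    N = u.size
    k = 2 * np.pi * np.fft.fftfreq(N, d=Lx / N)   # angular wavenumbers
    mask = np.abs(k) <= (2.0 / 3.0) * np.max(np.abs(k))
    u_hat = np.fft.fft(u)
    u_hat *= np.exp(dt * L_of_k(k))               # exact linear propagator per mode
    u_hat *= mask                                 # zero the top third of the spectrum
    return np.fft.ifft(u_hat).real

Lx, N, dt = 2 * np.pi, 64, 0.1
x = np.linspace(0, Lx, N, endpoint=False)
heat = lambda k: -k**2
u_smooth = dealiased_linear_step(np.sin(2 * x), heat, dt, Lx)   # decays by exp(-4 dt)
u_rough = dealiased_linear_step(np.sin(30 * x), heat, dt, Lx)   # |k| = 30 > 64/3: removed
```

Because the linear part is applied as an exact exponential per mode, the step is unconditionally stable for dissipative symbols; the mask only matters once nonlinear products generate high wavenumbers.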

The solver supports:

  • 1D and 2D spatial domains
  • First and second-order time evolution
  • Periodic and Dirichlet boundary conditions
  • Time-stepping schemes: default, ETD-RK4

Returns: list[np.ndarray]: A list of solution arrays at each saved time frame.

Side Effects:

  • Updates self.frames: stores solution snapshots
  • Updates self.energy_history: records the total energy at every time step (second-order problems without psiOp only)

Algorithm Overview: for each time step:

  1. Evaluate source contributions (if any).
  2. Apply time evolution:
      • Order 1:
          • With psiOp: uses step_order1_with_psi
          • With ETD-RK4: exponential time differencing (step_ETD_RK4)
          • Default: linear + nonlinear update
      • Order 2:
          • With psiOp: uses step_order2_with_psi
          • With ETD-RK4: uses step_ETD_RK4_order2
          • Default: exact cos/sin propagation of the linear part in Fourier space, plus an explicit nonlinear correction
  3. Enforce boundary conditions.
  4. Save a solution snapshot periodically.
  5. Record the energy (second-order systems without psiOp only).
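In the default second-order branch, each Fourier mode of the linear part ∂ₜ²u = -ω(k)²u is advanced exactly by a rotation in the (û, ∂ₜû) plane, which is what the cos_omega_dt, sin_omega_dt and inv_omega arrays in the listing encode. The following standalone 1D sketch covers that linear update only (no nonlinear correction or source), assuming ω(k) = |k| for the wave equation; the function name is hypothetical, and the k = 0 mode, where sin(ωΔt)/ω tends to Δt, is handled explicitly.

```python
import numpy as np

def linear_wave_step(u, v, omega, dt):
    """Advance u_tt = -omega(k)^2 u exactly over dt, one Fourier mode at a time."""
    u_hat, v_hat = np.fft.fft(u), np.fft.fft(v)
    cos_w, sin_w = np.cos(omega * dt), np.sin(omega * dt)
    # sin(omega*dt)/omega tends to dt as omega -> 0 (e.g. the mean mode k = 0)
    safe_omega = np.where(omega == 0, 1.0, omega)
    sin_over_omega = np.where(omega == 0, dt, sin_w / safe_omega)
    u_new_hat = cos_w * u_hat + sin_over_omega * v_hat
    v_new_hat = -omega * sin_w * u_hat + cos_w * v_hat
    return np.fft.ifft(u_new_hat).real, np.fft.ifft(v_new_hat).real

Lx, N = 2 * np.pi, 64
x = np.linspace(0, Lx, N, endpoint=False)
omega = np.abs(2 * np.pi * np.fft.fftfreq(N, d=Lx / N))   # wave equation with c = 1

u, v = np.cos(3 * x), np.zeros(N)
for _ in range(100):                                       # 100 steps of dt = 0.01, so t = 1
    u, v = linear_wave_step(u, v, omega, 0.01)
# Exact solution: u(x, t) = cos(3x) cos(3t), v(x, t) = -3 cos(3x) sin(3t)
```

Since the linear propagation is exact, any time-step error in this branch comes only from the explicit nonlinear and source corrections.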

def solve_stationary_psiOp(self, order=3):
    def solve_stationary_psiOp(self, order=3):
        """
        Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.

        This method computes the solution to a stationary (time-independent) pseudo-differential equation
        where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R 
        such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication 
        (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).

        The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order.
        Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.

        Parameters
        ----------
        order : int, default=3
            Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator.

        Returns
        -------
        ndarray
            The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.

        Raises
        ------
        ValueError
            If no pseudo-differential operator (psiOp) is defined.
            If linear or nonlinear terms other than psiOp are present.
            If the boundary condition is neither 'periodic' nor 'dirichlet'.
            If the symbol is not elliptic on the grid.
            If an initial condition is supplied, or no source term is provided for the right-hand side.

        Notes
        -----
        - The method assumes the problem is fully stationary: time derivatives must be absent.
        - Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
        - Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
        - With periodic boundary conditions, uses a fast path (direct multiplication in Fourier space)
          when the right-inverse symbol does not depend on the spatial variables.

        See Also
        --------
        right_inverse_asymptotic   : Constructs the asymptotic inverse of the pseudo-differential operator.
        kohn_nirenberg_fft         : Kohn–Nirenberg quantization on periodic grids (from psiop).
        kohn_nirenberg_nonperiodic : Kohn–Nirenberg quantization on non-periodic (Dirichlet) grids (from psiop).
        is_elliptic_numerically    : Verifies numerical ellipticity of the symbol.
        """

        print("\n******************************")
        print("* Solving the stationary PDE *")
        print("******************************\n")
        print("boundary condition: ", self.boundary_condition)

        if not self.has_psi:
            raise ValueError("Only supports problems with psiOp.")

        if self.linear_terms or self.nonlinear_terms:
            raise ValueError("Stationary psiOp problems must be linear and purely pseudo-differential.")

        if self.boundary_condition not in ('periodic', 'dirichlet'):
            raise ValueError(
                "For stationary PDEs, boundary conditions must be explicitly defined. "
                "Supported types are 'periodic' and 'dirichlet'."
            )

        if self.dim == 1:
            x = self.x
            xi = symbols('xi', real=True)
            spatial_vars = (x,)
            freq_vars = (xi,)
            X, KX = self.X, self.KX
        elif self.dim == 2:
            x, y = self.x, self.y
            xi, eta = symbols('xi eta', real=True)
            spatial_vars = (x, y)
            freq_vars = (xi, eta)
            X, Y, KX, KY = self.X, self.Y, self.KX, self.KY
        else:
            raise ValueError("Unsupported spatial dimension.")

        total_symbol = sum(coeff * psi.expr for coeff, psi in self.psi_ops)
        psi_total = PseudoDifferentialOperator(total_symbol, spatial_vars, mode='symbol')

        # Check ellipticity
        if self.dim == 1:
            is_elliptic = psi_total.is_elliptic_numerically(X, KX)
        else:
            is_elliptic = psi_total.is_elliptic_numerically((X[:, 0], Y[0, :]), (KX[:, 0], KY[0, :]))
        if not is_elliptic:
            raise ValueError("❌ The pseudo-differential symbol is not numerically elliptic on the grid.")
        print("✅ Elliptic pseudo-differential symbol: inversion allowed.")

        R_symbol = psi_total.right_inverse_asymptotic(order=order)
        print('Right inverse asymptotic symbol:')
        pprint(R_symbol, num_columns=NUM_COLS)

        # ========================================================================
        # Lambdify R with all variables so that its signature is always the same
        # ========================================================================
        if self.dim == 1:
            # Signature (x, xi)
            R_func = lambdify((x, xi), R_symbol, modules='numpy')
        elif self.dim == 2:
            # Signature (x, y, xi, eta)
            R_func = lambdify((x, y, xi, eta), R_symbol, modules='numpy')

        # Prepare right-hand side
        if self.source_terms:
            f_expr = sum(self.source_terms)
            # Lambdify over all spatial variables, so that a source depending on only
            # one coordinate, or on none (a constant), is handled uniformly.
            # Adding a zero array broadcasts scalar results to the full grid.
            f_func = lambdify(spatial_vars, -f_expr, modules='numpy')
            if self.dim == 1:
                rhs = f_func(self.x_grid) + np.zeros_like(self.x_grid)
            else:
                rhs = f_func(self.X, self.Y) + np.zeros_like(self.X)
        elif self.initial_condition:
            raise ValueError('Initial condition should be None for a stationary equation.')
        else:
            raise ValueError('No source term provided to construct the right-hand side.')

        f_hat = self.fft(rhs)

        # ========================================================================
        # Application of the inverse operator
        # ========================================================================
        if self.boundary_condition == 'periodic':
            if self.dim == 1:
                # Check if optimization is possible
                if not R_symbol.has(x):
                    print('⚡ Optimization: symbol independent of x – direct product in Fourier.')
                    # Wrapper that ignores x
                    def R_func_optimized(kx_val):
                        return R_func(0.0, kx_val)  # x=0 since it doesn't matter

                    R_vals = R_func_optimized(self.KX)
                    u_hat = R_vals * f_hat
                    u = self.ifft(u_hat)
                else:
                    print('⚙️ 1D Kohn-Nirenberg quantization')
                    from psiop import kohn_nirenberg_fft
                    u = kohn_nirenberg_fft(
                        u_vals=rhs,
                        symbol_func=R_func,  # signature (x, xi)
                        x_grid=self.x_grid,
                        kx=self.kx,
                        fft_func=self.fft,
                        ifft_func=self.ifft,
                        dim=1
                    )

            elif self.dim == 2:
                if not R_symbol.has(x) and not R_symbol.has(y):
                    print('⚡ Optimization: Symbol independent of x and y – direct product in 2D Fourier.')
                    # Wrapper that ignores x, y
                    def R_func_optimized(kx_val, ky_val):
                        return R_func(0.0, 0.0, kx_val, ky_val)

                    R_vals = R_func_optimized(self.KX, self.KY)
                    u_hat = R_vals * f_hat
                    u = self.ifft(u_hat)
                else:
                    print('⚙️ 2D Kohn-Nirenberg quantization')
                    from psiop import kohn_nirenberg_fft
                    u = kohn_nirenberg_fft(
                        u_vals=rhs,
                        symbol_func=R_func,  # signature (x, y, xi, eta)
                        x_grid=self.x_grid,
                        kx=self.kx,
                        fft_func=self.fft,
                        ifft_func=self.ifft,
                        dim=2,
                        y_grid=self.y_grid,
                        ky=self.ky
                    )
            self.u = u
            return u

        elif self.boundary_condition == 'dirichlet':
            from psiop import kohn_nirenberg_nonperiodic

            if self.dim == 1:
                u = kohn_nirenberg_nonperiodic(
                    u_vals=rhs,
                    x_grid=self.x_grid,
                    xi_grid=self.kx,
                    symbol_func=R_func  # signature (x, xi)
                )
            elif self.dim == 2:
                u = kohn_nirenberg_nonperiodic(
                    u_vals=rhs,
                    x_grid=(self.x_grid, self.y_grid),
                    xi_grid=(self.kx, self.ky),
                    symbol_func=R_func  # signature (x, y, xi, eta)
                )
            self.u = u
            return u

        else:
            raise ValueError(f"Invalid boundary condition '{self.boundary_condition}'. Supported types are 'periodic' and 'dirichlet'.")

Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.

This method computes the solution to a stationary (time-independent) pseudo-differential equation where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).

The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order. Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.

Parameters

order : int, default=3
    Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator.

Returns

ndarray
    The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.

Raises

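ValueError

  • If no pseudo-differential operator (psiOp) is defined.
  • If linear or nonlinear terms other than psiOp are present.
  • If the boundary condition is neither 'periodic' nor 'dirichlet'.
  • If the symbol is not numerically elliptic on the grid.
  • If an initial condition is supplied, or no source term is provided for the right-hand side.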

Notes

  • The method assumes the problem is fully stationary: time derivatives must be absent.
  • Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
  • Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
  • With periodic boundary conditions, uses a fast path (direct multiplication in Fourier space) when the right-inverse symbol does not depend on the spatial variables.

See Also

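  • right_inverse_asymptotic : Constructs the asymptotic inverse of the pseudo-differential operator.
  • kohn_nirenberg_fft : Kohn–Nirenberg quantization on periodic grids (imported from psiop).
  • kohn_nirenberg_nonperiodic : Kohn–Nirenberg quantization on non-periodic (Dirichlet) grids (imported from psiop).
  • is_elliptic_numerically : Verifies numerical ellipticity of the symbol.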
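For an operator whose symbol p(ξ) does not depend on x, the correction terms of the asymptotic right inverse involve x-derivatives and vanish, so R reduces to the Fourier multiplier 1/p(ξ); this is the case handled by the periodic fast path. A minimal standalone 1D sketch of that case follows. The function name is hypothetical, the ellipticity test is a simple |p| threshold rather than is_elliptic_numerically, and f is used directly as the right-hand side (the solver itself builds the right-hand side from the negated source term).

```python
import numpy as np

def solve_x_independent_symbol(f_vals, Lx, symbol):
    """Solve P[u] = f on a periodic 1D grid when P has an x-independent symbol p(xi).

    The right inverse is the multiplier 1/p(xi), so u = ifft(fft(f) / p(xi)).
    Requires p(xi) != 0 on the grid (ellipticity).
    """
    N = f_vals.size
    xi = 2 * np.pi * np.fft.fftfreq(N, d=Lx / N)
    p = symbol(xi)
    if np.min(np.abs(p)) < 1e-12:
        raise ValueError("symbol vanishes on the grid; P is not invertible there")
    return np.fft.ifft(np.fft.fft(f_vals) / p)

# Example: (1 - d^2/dx^2) u = cos(3x) on [0, 2*pi); the exact solution is cos(3x)/10.
Lx, N = 2 * np.pi, 64
x = np.linspace(0, Lx, N, endpoint=False)
u = solve_x_independent_symbol(np.cos(3 * x), Lx, lambda xi: 1 + xi**2)
```

A symbol such as ξ² fails the check because it vanishes at ξ = 0, which mirrors why the solver refuses to invert non-elliptic symbols.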

def step_ETD_RK4(self, u):
    def step_ETD_RK4(self, u):
        """
        Perform one fourth-order Exponential Time Differencing Runge-Kutta (ETD-RK4) time step 
        for first-order in time PDEs of the form:

            ∂ₜu = L u + N(u)

        where L is a linear operator (possibly nonlocal or pseudo-differential), and N is a 
        nonlinear term treated via pseudo-spectral methods. This method evaluates the 
        exponential integrator up to fourth-order accuracy in time.

        The ETD-RK4 scheme uses four stages to approximate the integral in the variation-of-constants formula:

            uⁿ⁺¹ = e^(L Δt) uⁿ + Δt ∫₀¹ e^(L Δt (1 - τ)) N(u(tₙ + τΔt)) dτ

        where the nonlinear term N is evaluated at the intermediate stages.

        Parameters
            u (np.ndarray): Current solution in real space (physical grid values).

        Returns:
            np.ndarray: Updated solution in real space after one ETD-RK4 time step.

        Notes:
        - The linear part L is diagonal in Fourier space and is evaluated as self.L(k) on the wavenumber grid at each call.
        - Nonlinear terms are evaluated in physical space and transformed via FFT.
        - The functions φ₁(z) and φ₂(z) are entire functions arising from the ETD scheme:

              φ₁(z) = (eᶻ - 1)/z   if z ≠ 0
                     = 1            if z = 0

              φ₂(z) = (eᶻ - 1 - z)/z²   if z ≠ 0
                     = ½              if z = 0

        - This implementation assumes periodic boundary conditions and uses spectral differentiation via FFT.
        - See Hochbruck & Ostermann (2010) for theoretical background on exponential integrators.

        See Also:
            step_ETD_RK4_order2 : For second-order in time equations.
            psiOp_apply           : For applying pseudo-differential operators.
            apply_nonlinear      : For handling nonlinear terms in the PDE.
        """
        dt = self.dt
        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)

        E  = np.exp(dt * L_fft)
        E2 = np.exp(dt * L_fft / 2)

        def phi1(z):
            return np.where(np.abs(z) > 1e-12, (np.exp(z) - 1) / z, 1.0)

        def phi2(z):
            return np.where(np.abs(z) > 1e-12, (np.exp(z) - 1 - z) / z**2, 0.5)

        phi1_dtL = phi1(dt * L_fft)
        phi2_dtL = phi2(dt * L_fft)

        fft = self.fft
        ifft = self.ifft

        u_hat = fft(u)
        N1 = fft(self.apply_nonlinear(u))

        a = ifft(E2 * (u_hat + 0.5 * dt * N1 * phi1_dtL))
        N2 = fft(self.apply_nonlinear(a))

        b = ifft(E2 * (u_hat + 0.5 * dt * N2 * phi1_dtL))
        N3 = fft(self.apply_nonlinear(b))

        c = ifft(E * (u_hat + dt * N3 * phi1_dtL))
        N4 = fft(self.apply_nonlinear(c))

        u_new_hat = E * u_hat + dt * (
            N1 * phi1_dtL + 2 * (N2 + N3) * phi2_dtL + N4 * phi1_dtL
        ) / 6

        return ifft(u_new_hat)

Perform one fourth-order Exponential Time Differencing Runge-Kutta (ETD-RK4) time step for PDEs that are first order in time, of the form:

∂ₜu = L u + N(u)

where L is a linear operator (possibly nonlocal or pseudo-differential), and N is a nonlinear term treated via pseudo-spectral methods. This method evaluates the exponential integrator up to fourth-order accuracy in time.

The ETD-RK4 scheme uses four stages to approximate the integral in the variation-of-constants formula:

uⁿ⁺¹ = e^(L Δt) uⁿ + Δt ∫₀¹ e^(L Δt (1 - τ)) N(u(tₙ + τΔt)) dτ

where the nonlinear term N is evaluated at the intermediate stages.

Parameters: u (np.ndarray): Current solution in real space (physical grid values).

Returns: np.ndarray: Updated solution in real space after one ETD-RK4 time step.

Notes:

  • The linear part L is diagonal in Fourier space and is evaluated as self.L(k) on the wavenumber grid at each call.
  • Nonlinear terms are evaluated in physical space and transformed via FFT.
  • The functions φ₁(z) and φ₂(z) are entire functions arising from the ETD scheme:

        φ₁(z) = (eᶻ - 1)/z        if z ≠ 0,   φ₁(0) = 1
        φ₂(z) = (eᶻ - 1 - z)/z²   if z ≠ 0,   φ₂(0) = ½

  • This implementation assumes periodic boundary conditions and uses spectral differentiation via FFT.
  • See Hochbruck & Ostermann (2010) for theoretical background on exponential integrators.

See Also:

  • step_ETD_RK4_order2 : For second-order in time equations.
  • psiOp_apply : For applying pseudo-differential operators.
  • apply_nonlinear : For handling nonlinear terms in the PDE.
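For comparison, the textbook fourth-order ETD Runge–Kutta scheme of Cox and Matthews needs three coefficient functions (combinations of φ₁, φ₂ and φ₃) and reduces to the classical RK4 weights 1/6, 1/3, 1/3, 1/6 when L = 0, which makes a convenient consistency check for any ETD-RK4 implementation. The sketch below is a standalone version for a diagonal (Fourier-space) L, with the coefficients evaluated by the contour-averaging technique of Kassam and Trefethen to avoid cancellation near z = 0. The function names and the scalar test problem are illustrative assumptions, not part of psipy.

```python
import numpy as np

def etdrk4_coefficients(L, dt, M=32):
    """Cox–Matthews ETDRK4 coefficients for a diagonal linear operator L (sketch).

    Each coefficient is an entire function of z = dt*L; it is evaluated as the mean
    over M points on a unit circle centred at z, which avoids the cancellation of
    the closed-form expressions when |z| is small.
    """
    z = dt * np.asarray(L, dtype=complex)
    r = np.exp(2j * np.pi * (np.arange(1, M + 1) - 0.5) / M)  # full circle: complex z allowed
    w = z[..., np.newaxis] + r
    ew = np.exp(w)
    E, E2 = np.exp(z), np.exp(z / 2)
    Q = dt * np.mean((np.exp(w / 2) - 1) / w, axis=-1)
    f1 = dt * np.mean((-4 - w + ew * (4 - 3 * w + w**2)) / w**3, axis=-1)
    f2 = dt * np.mean((2 + w + ew * (w - 2)) / w**3, axis=-1)
    f3 = dt * np.mean((-4 - 3 * w - w**2 + ew * (4 - w)) / w**3, axis=-1)
    return E, E2, Q, f1, f2, f3

def etdrk4_step(u_hat, nonlinear_hat, coeffs):
    """One ETDRK4 step for d(u_hat)/dt = L*u_hat + N(u_hat), entirely in spectral space."""
    E, E2, Q, f1, f2, f3 = coeffs
    Nu = nonlinear_hat(u_hat)
    a = E2 * u_hat + Q * Nu
    Na = nonlinear_hat(a)
    b = E2 * u_hat + Q * Na
    Nb = nonlinear_hat(b)
    c = E2 * a + Q * (2 * Nb - Nu)
    Nc = nonlinear_hat(c)
    return E * u_hat + f1 * Nu + 2 * f2 * (Na + Nb) + f3 * Nc

# Scalar check: u' = -u + u**2, u(0) = 1/2, exact solution u(t) = 1 / (1 + e^t).
def integrate(n_steps, T=1.0):
    coeffs = etdrk4_coefficients(-1.0, T / n_steps)
    u = np.complex128(0.5)
    for _ in range(n_steps):
        u = etdrk4_step(u, lambda v: v**2, coeffs)
    return u.real

exact = 1.0 / (1.0 + np.exp(1.0))
err_10 = abs(integrate(10) - exact)
err_20 = abs(integrate(20) - exact)
# Halving dt should shrink the error by roughly 2**4 = 16.
```

With the nonlinear term switched off, the step returns E·û exactly, so the linear part is integrated without any time-step error.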

def step_ETD_RK4_order2(self, u, v):
    def step_ETD_RK4_order2(self, u, v):
        """
        Perform one time step of a classical fourth-order Runge-Kutta (RK4) scheme for second-order-in-time PDEs.

        This method evolves the solution u and its time derivative v forward in time by one step.
        It is designed for systems of the form:

            ∂ₜ²u = L u + N(u)

        where L is a linear operator and N is a nonlinear term computed via self.apply_nonlinear.

        Despite the method's name, no exponential integrator is used: the equation is rewritten as the
        first-order system ∂ₜu = v, ∂ₜv = L u + N(u), and classical RK4 is applied to it, with L u
        evaluated spectrally (multiplication by self.L(k) in Fourier space) at every stage.

        Parameters:
            u (np.ndarray): Current solution array in real space.
            v (np.ndarray): Current time derivative of the solution (∂ₜu) in real space.

        Returns:
            tuple: (u_new, v_new), updated solution and its time derivative after one time step.

        Notes:
            - Assumes periodic boundary conditions and uses FFT-based spectral methods.
            - Handles both 1D and 2D problems.
            - Being explicit, the scheme is only conditionally stable: for L(k) = -ω(k)², it
              requires roughly max ω(k)·dt ≤ 2√2 (the imaginary-axis stability limit of RK4).
        """
        dt = self.dt

        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)
        fft = self.fft
        ifft = self.ifft

        def rhs(u_val):
            return ifft(L_fft * fft(u_val)) + self.apply_nonlinear(u_val, is_v=False)

        # Stage A
        A = rhs(u)
        ua = u + 0.5 * dt * v
        va = v + 0.5 * dt * A

        # Stage B
        B = rhs(ua)
        ub = u + 0.5 * dt * va
        vb = v + 0.5 * dt * B

        # Stage C
        C = rhs(ub)
        uc = u + dt * vb

        # Stage D
        D = rhs(uc)

        # Final update. Eliminating the stage velocities of classical RK4
        # (v + 2*va + 2*vb + vc = 6*v + dt*(A + B + C)) gives the weights
        # A + B + C for the displacement and A + 2B + 2C + D for the velocity.
        u_new = u + dt * v + (dt**2 / 6.0) * (A + B + C)
        v_new = v + (dt / 6.0) * (A + 2*B + 2*C + D)

        return u_new, v_new

Perform one time step of a classical fourth-order Runge–Kutta (RK4) scheme for second-order-in-time PDEs.

This method evolves the solution u and its time derivative v forward in time by one step. It is designed for systems of the form:

∂ₜ²u = L u + N(u)

where L is a linear operator and N is a nonlinear term computed via self.apply_nonlinear.

Despite the method's name, no exponential integrator is used: the equation is rewritten as the first-order system ∂ₜu = v, ∂ₜv = L u + N(u), and classical RK4 is applied to it, with L u evaluated spectrally (multiplication by self.L(k) in Fourier space) at every stage.

Parameters:

  • u (np.ndarray): Current solution array in real space.
  • v (np.ndarray): Current time derivative of the solution (∂ₜu) in real space.

Returns: tuple: (u_new, v_new), the updated solution and its time derivative after one time step.

Notes:

  • Assumes periodic boundary conditions and uses FFT-based spectral methods.
  • Handles both 1D and 2D problems.
  • Being explicit, the scheme is only conditionally stable: for L(k) = -ω(k)², it requires roughly max ω(k)·dt ≤ 2√2 (the imaginary-axis stability limit of RK4).
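Applying classical RK4 to the first-order system ∂ₜu = v, ∂ₜv = f(u), with f(u) = L u + N(u), and eliminating the stage velocities gives the Runge–Kutta–Nyström form: the velocity is updated with the weights (A + 2B + 2C + D)/6, while the displacement is u + Δt v + Δt²(A + B + C)/6. The standalone scalar sketch below (hypothetical function name; harmonic oscillator f(u) = -u as the test problem) uses the same stage points as the method above and checks the expected fourth-order convergence.

```python
import numpy as np

def rk4_second_order_step(u, v, dt, f):
    """Classical RK4 for u'' = f(u), written as u' = v, v' = f(u) (Nystrom form)."""
    A = f(u)
    ua, va = u + 0.5 * dt * v, v + 0.5 * dt * A
    B = f(ua)
    ub, vb = u + 0.5 * dt * va, v + 0.5 * dt * B
    C = f(ub)
    uc = u + dt * vb
    D = f(uc)
    # Stage velocities satisfy v + 2*va + 2*vb + vc = 6*v + dt*(A + B + C)
    u_new = u + dt * v + dt**2 / 6.0 * (A + B + C)
    v_new = v + dt / 6.0 * (A + 2 * B + 2 * C + D)
    return u_new, v_new

def integrate(n_steps, T=1.0):
    u, v, dt = 1.0, 0.0, T / n_steps
    for _ in range(n_steps):
        u, v = rk4_second_order_step(u, v, dt, lambda w: -w)
    return u, v

# Exact solution of u'' = -u, u(0) = 1, u'(0) = 0: u(t) = cos t, u'(t) = -sin t.
err_10 = abs(integrate(10)[0] - np.cos(1.0))
err_20 = abs(integrate(20)[0] - np.cos(1.0))
```

Using A + 2B + 2C + D in the displacement update instead would double the Δt² term and reduce the scheme to first order, so a convergence check of this kind is a cheap guard.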

def check_cfl_condition(self):
    def check_cfl_condition(self):
        """
        Check the CFL (Courant–Friedrichs–Lewy) condition based on group velocity 
        for second-order time-dependent PDEs.

        This method verifies whether the chosen time step dt satisfies the numerical stability 
        condition derived from the maximum wave propagation speed in the system. It supports both 
        1D and 2D problems, with or without a symbolic dispersion relation ω(k).

        The CFL condition ensures that information does not propagate further than one grid cell 
        per time step. A fixed safety factor of 0.5 is applied to ensure robustness.

        Notes:

        - In 1D, the group velocity v_g(k) = dω/dk is used to compute the maximum wave speed.
        - In 2D, the x- and y-directional group velocities are evaluated independently.
        - If no dispersion relation is available, the largest |Im L(k)| on the grid is used
          as a proxy for the wave speed (it is an angular frequency, so the bound is heuristic).

        Raises:
        -------
        NotImplementedError: 
            If the spatial dimension is not 1D or 2D.

        Prints:
        -------
        Warning message if the current time step dt exceeds the CFL-stable limit.
        """
        print("\n*****************")
        print("* CFL condition *")
        print("*****************\n")

        cfl_factor = 0.5  # Safety factor

        if self.dim == 1:
            if self.temporal_order == 2 and hasattr(self, 'omega'):
                # Sort the (FFT-ordered) wavenumbers so np.gradient sees a monotone grid
                k_vals = np.sort(self.kx)
                omega_vals = np.real(self.omega(k_vals))
                with np.errstate(divide='ignore', invalid='ignore'):
                    v_group = np.gradient(omega_vals, k_vals)
                max_speed = np.max(np.abs(v_group))
            else:
                max_speed = np.max(np.abs(np.imag(self.L(self.kx))))

            dx = self.Lx / self.Nx
            cfl_limit = cfl_factor * dx / max_speed if max_speed != 0 else np.inf

            if self.dt > cfl_limit:
                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")

        elif self.dim == 2:
            if self.temporal_order == 2 and hasattr(self, 'omega'):
                # Sorted wavenumbers along each axis (y uses ky, not kx)
                kx_vals = np.sort(self.kx)
                ky_vals = np.sort(self.ky)
                omega_x = np.real(self.omega(kx_vals, 0))
                omega_y = np.real(self.omega(0, ky_vals))
                with np.errstate(divide='ignore', invalid='ignore'):
                    v_group_x = np.gradient(omega_x, kx_vals)
                    v_group_y = np.gradient(omega_y, ky_vals)
                max_speed_x = np.max(np.abs(v_group_x))
                max_speed_y = np.max(np.abs(v_group_y))
            else:
                max_speed_x = np.max(np.abs(np.imag(self.L(self.kx, 0))))
                max_speed_y = np.max(np.abs(np.imag(self.L(0, self.ky))))

            dx = self.Lx / self.Nx
            dy = self.Ly / self.Ny
            cfl_limit = cfl_factor / (max_speed_x / dx + max_speed_y / dy) if (max_speed_x + max_speed_y) != 0 else np.inf

            if self.dt > cfl_limit:
                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")

        else:
            raise NotImplementedError("Only 1D and 2D problems are supported.")

Check the CFL (Courant–Friedrichs–Lewy) condition based on group velocity for second-order time-dependent PDEs.

This method verifies whether the chosen time step dt satisfies the numerical stability condition derived from the maximum wave propagation speed in the system. It supports both 1D and 2D problems, with or without a symbolic dispersion relation ω(k).

The CFL condition ensures that information does not propagate further than one grid cell per time step. A fixed safety factor of 0.5 is applied to ensure robustness.

Notes:

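  • In 1D, the group velocity v_g(k) = dω/dk is used to compute the maximum wave speed.
  • In 2D, the x- and y-directional group velocities are evaluated independently and combined as dt ≤ 0.5 / (max|v_x|/dx + max|v_y|/dy).
  • If no dispersion relation is available, the largest |Im L(k)| over the grid wavenumbers is used as a proxy for the wave speed. This quantity is an angular frequency rather than a velocity, so the resulting bound is only heuristic.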

Raises:

NotImplementedError: If the spatial dimension is not 1D or 2D.

Prints:

Warning message if the current time step dt exceeds the CFL-stable limit.
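The 1D group-velocity branch can be sketched as follows, assuming a Klein–Gordon-type dispersion relation ω(k) = √(c²k² + m²), whose group velocity c²k/ω is bounded by c, and angular wavenumbers taken from numpy.fft.fftfreq. The function name is hypothetical. Sorting the wavenumbers first keeps np.gradient from differencing across the jump between positive and negative frequencies in FFT ordering.

```python
import numpy as np

def cfl_limit_1d(omega, Lx, Nx, cfl_factor=0.5):
    """Largest dt allowed by dt <= cfl_factor * dx / max|d omega / dk| (sketch)."""
    dx = Lx / Nx
    # Angular wavenumbers of an Nx-point periodic grid, sorted into increasing order
    k = np.sort(2 * np.pi * np.fft.fftfreq(Nx, d=dx))
    v_group = np.gradient(np.real(omega(k)), k)
    max_speed = np.max(np.abs(v_group))
    return cfl_factor * dx / max_speed if max_speed > 0 else np.inf

Lx, Nx, c, m = 2 * np.pi, 64, 2.0, 1.0
omega = lambda k: np.sqrt(c**2 * k**2 + m**2)
dx = Lx / Nx
dt_max = cfl_limit_1d(omega, Lx, Nx)   # slightly above 0.5 * dx / c, since |v_g| < c
```

Because finite differences of a convex ω(k) never exceed the true maximum slope, the estimated speed stays below c here and the limit approaches 0.5·dx/c from above as the grid is refined.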

def check_symbol_conditions(self, k_range=None, verbose=True):
    def check_symbol_conditions(self, k_range=None, verbose=True):
        """
        Check analytic conditions on the linear symbol self.L_symbolic.

        This method evaluates three key properties of the Fourier multiplier
        symbol a(k) = self.L(k), which are relevant to well-posedness, stability,
        and numerical efficiency. The checks apply to both 1D and 2D cases.

        Conditions checked:
        -------------------
        1. **Stability condition**: Re(a(k)) ≤ 0 for all k ≠ 0.
           Ensures that the system does not exhibit exponential growth in time.

        2. **Dissipation condition**: Re(a(k)) ≤ -δ |k|² for large |k|.
           Ensures sufficient damping at high frequencies to avoid oscillatory instability.

        3. **Growth condition**: |a(k)| ≤ C (1 + |k|)^m with m ≤ 4.
           Ensures that the symbol does not grow too rapidly with frequency,
           which would otherwise cause numerical instability or unphysical amplification.

        Parameters
        ----------
        k_range : tuple or None, optional
            Range of frequencies to test, in the form (k_min, k_max, N).
            If None, defaults are used: [-10, 10] with 500 points in 1D, or [-10, 10]
            with 100 points per axis in 2D.

        verbose : bool, default=True
            If True, prints detailed results of each condition check.

        Returns
        -------
        None
            Results are printed directly to the console.

        Raises
        ------
        ValueError
            If the spatial dimension is not 1D or 2D.

        Notes
        -----
        - In 2D, the radial frequency |k| = √(kx² + ky²) is used for comparisons.
        - The dissipation check uses the fixed threshold -0.01 |k|² (δ = 0.01,
          exponent 2) and is applied only where |k| > 2.
        - The growth check computes |a(k)| / (1 + |k|)⁴; a maximum above 100 on the
          sampled grid is reported as rapid growth.
        - The checks sample a finite frequency window, so they are heuristic rather
          than proofs that the conditions hold for all k.
        - This function is typically called during solver setup or the analysis phase.

        See Also
        --------
        analyze_wave_propagation : Further symbolic and numerical analysis of dispersion.
        plot_symbol : Visualizes the symbol's behavior over the frequency domain.
        """
        print("\n********************")
        print("* Symbol condition *")
        print("********************\n")

        if self.dim == 1:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 500)
            else:
                k_min, k_max, N = k_range
                k_vals = np.linspace(k_min, k_max, N)

            L_vals = self.L(k_vals)
            k_abs = np.abs(k_vals)

        elif self.dim == 2:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 100)
            else:
                k_min, k_max, N = k_range
                k_vals = np.linspace(k_min, k_max, N)

            KX, KY = np.meshgrid(k_vals, k_vals)
            L_vals = self.L(KX, KY)
            k_abs = np.sqrt(KX**2 + KY**2)

        else:
            raise ValueError("Only 1D and 2D dimensions are supported.")

        # A constant symbol may evaluate to a scalar; give it the grid's shape
        L_vals = np.broadcast_to(L_vals, k_abs.shape)
        re_vals = np.real(L_vals)
        abs_vals = np.abs(L_vals)

        # === Condition 1: Stability (k = 0 is excluded, as documented)
        re_nonzero = re_vals[k_abs > 0]
        if np.any(re_nonzero > 1e-12):
            max_pos = np.max(re_nonzero)
            if verbose:
                print(f"❌ Stability violated: max Re(a(k)) = {max_pos}")
            # Instability is reported even when verbose=False
            print("Unstable symbol: Re(a(k)) > 0")
        elif verbose:
            print("✅ Spectral stability satisfied: Re(a(k)) ≤ 0")

        # === Condition 2: Dissipation (Re a(k) ≤ -0.01 |k|² for |k| > 2)
        mask = k_abs > 2
        if np.any(mask):
            re_decay = re_vals[mask]
            expected_decay = -0.01 * k_abs[mask]**2
            if np.any(re_decay > expected_decay + 1e-6):
                if verbose:
                    print("⚠️ Insufficient high-frequency dissipation")
            else:
                if verbose:
                    print("✅ Proper high-frequency dissipation")

        # === Condition 3: Growth (|a(k)| / (1 + |k|)⁴ ≤ 100 on the grid)
        growth_ratio = abs_vals / (1 + k_abs)**4
        if np.max(growth_ratio) > 100:
            if verbose:
                print("⚠️ Symbol grows rapidly: |a(k)| ≳ |k|^4")
        else:
            if verbose:
                print("✅ Reasonable spectral growth")

        if verbose:
            print("✔ Symbol analysis completed.")

Check analytic conditions on the linear symbol self.L_symbolic. This method evaluates three key properties of the Fourier multiplier symbol a(k) = self.L(k), which are relevant to well-posedness, stability, and numerical efficiency. The checks apply to both 1D and 2D cases.

Conditions checked:

  1. Stability condition: Re(a(k)) ≤ 0 for all k ≠ 0. Ensures that the system does not exhibit exponential growth in time.

  2. Dissipation condition: Re(a(k)) ≤ -δ |k|² for large |k|. Ensures sufficient damping at high frequencies to avoid oscillatory instability.

  3. Growth condition: |a(k)| ≤ C (1 + |k|)^m with m ≤ 4. Ensures that the symbol does not grow too rapidly with frequency, which would otherwise cause numerical instability or unphysical amplification.

Parameters

k_range : tuple or None, optional
    Range of frequencies to test, in the form (k_min, k_max, N). If None, defaults are used: [-10, 10] with 500 points in 1D, or [-10, 10] with 100 points per axis in 2D.

verbose : bool, default=True
    If True, prints detailed results of each condition check.

Returns:

None
    Results are printed directly to the console.

Raises:

ValueError
    If the spatial dimension is not 1D or 2D.

Notes:

  • In 2D, the radial frequency |k| = √(kx² + ky²) is used for comparisons.
  • The dissipation check uses the fixed threshold -0.01 |k|² (δ = 0.01, exponent 2) and is applied only where |k| > 2.
  • The growth check computes |a(k)| / (1 + |k|)⁴; a maximum above 100 on the sampled grid is reported as rapid growth.
  • This function is typically called during solver setup or analysis phase.

See Also:

analyze_wave_propagation : Further symbolic and numerical analysis of dispersion.
plot_symbol : Visualizes the symbol's behavior over the frequency domain.
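The three checks can be sketched as a pure function so the thresholds are explicit. It mirrors the tolerances documented above: 1e-12 for stability, δ = 0.01 applied for |k| > 2, and a growth bound of 100 · (1 + |k|)⁴. The function name `symbol_conditions` and its dictionary return value are illustrative assumptions, not part of psipy, and the sketch covers only the 1D case.

```python
import numpy as np


def symbol_conditions(a, k, delta=0.01, k_cut=2.0, growth_bound=100.0):
    """Evaluate the three symbol checks on a 1D grid and return booleans (sketch).

    a : vectorized callable giving the Fourier symbol a(k).
    k : 1D array of sample wavenumbers.
    """
    vals = np.broadcast_to(a(k), k.shape)   # constant symbols may return a scalar
    k_abs = np.abs(k)
    re_vals = np.real(vals)

    # 1. Stability: Re a(k) <= 0 (up to round-off) for every sampled k != 0
    stable = not np.any(re_vals[k_abs > 0] > 1e-12)

    # 2. Dissipation: Re a(k) <= -delta |k|^2 wherever |k| > k_cut
    high = k_abs > k_cut
    dissipative = not np.any(re_vals[high] > -delta * k_abs[high] ** 2 + 1e-6)

    # 3. Growth: |a(k)| / (1 + |k|)^4 stays at or below growth_bound
    moderate_growth = bool(np.max(np.abs(vals) / (1 + k_abs) ** 4) <= growth_bound)

    return {"stable": stable, "dissipative": dissipative, "moderate_growth": moderate_growth}


k = np.linspace(-10, 10, 500)
print(symbol_conditions(lambda k: -k**2, k))       # heat equation: all three hold
print(symbol_conditions(lambda k: -1j * k**2, k))  # Schrödinger: stable, not dissipative
```

The Schrödinger case shows why the dissipation check only issues a warning: a purely dispersive symbol has Re a(k) = 0, which is stable but provides no high-frequency damping.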

def analyze_wave_propagation(self):
    def analyze_wave_propagation(self):
        """
        Analyze wave propagation characteristics based on the dispersion relation ω(k).

        This method visualizes key wave properties in both 1D and 2D settings:

        - Dispersion relation: ω(k)
        - Phase velocity: v_p(k) = ω(k)/k in 1D, Re ω(k)/|k| in 2D
        - Group velocity: v_g(k) = ∇ₖ ω(k), approximated by finite differences
        - Anisotropy in 2D (via the magnitude of the group velocity)

        The symbolic dispersion relation 'omega_symbolic' must be defined beforehand.
        This is typically available only for second-order-in-time equations.

        In 1D:
            Plots ω(k), v_p(k), and v_g(k) for k in [-10, 10].

        In 2D:
            Displays heatmaps of ω(kx, ky), v_p(kx, ky), and |v_g(kx, ky)| over a 2D wavenumber grid.

        Notes:
            No exception is raised if 'omega_symbolic' is not defined or if the
            dimension is not 1D or 2D: the method prints a message and returns None.

        Side Effects:
            Generates and displays matplotlib plots.
        """
        print("\n*****************************")
        print("* Wave propagation analysis *")
        print("*****************************\n")
        if not hasattr(self, 'omega_symbolic'):
            print("❌ omega_symbolic not defined. Only available for 2nd order in time.")
            return

        if self.dim == 1:
            k = self.k_symbols[0]
            omega_func = lambdify(k, self.omega_symbolic, 'numpy')

            k_vals = np.linspace(-10, 10, 1000)
            # broadcast_to handles dispersion relations that evaluate to a constant
            omega_vals = np.broadcast_to(omega_func(k_vals), k_vals.shape)

            with np.errstate(divide='ignore', invalid='ignore'):
                v_phase = np.where(k_vals != 0, omega_vals / k_vals, 0.0)

            dk = k_vals[1] - k_vals[0]
            v_group = np.gradient(omega_vals, dk)

            plt.figure(figsize=(10, 6))
            plt.plot(k_vals, omega_vals, label=r'$\omega(k)$')
            plt.plot(k_vals, v_phase, label=r'$v_p(k)$')
            plt.plot(k_vals, v_group, label=r'$v_g(k)$')
            plt.title("1D Wave Propagation Analysis")
            plt.xlabel("k")
            plt.grid()
            plt.legend()
            plt.tight_layout()
            plt.show()

        elif self.dim == 2:
            kx, ky = self.k_symbols
            omega_func = lambdify((kx, ky), self.omega_symbolic, 'numpy')

            k_vals = np.linspace(-10, 10, 200)
            KX, KY = np.meshgrid(k_vals, k_vals)
            K_mag = np.sqrt(KX**2 + KY**2)
            K_mag[K_mag == 0] = 1e-8  # Avoid division by 0

            omega_vals = np.broadcast_to(omega_func(KX, KY), KX.shape)
            v_phase = np.real(omega_vals) / K_mag

            dk = k_vals[1] - k_vals[0]
            # meshgrid uses 'xy' indexing: axis 0 runs along ky, axis 1 along kx
            domega_dky = np.gradient(omega_vals, dk, axis=0)
            domega_dkx = np.gradient(omega_vals, dk, axis=1)
            v_group_norm = np.sqrt(np.abs(domega_dkx)**2 + np.abs(domega_dky)**2)

            fig, axs = plt.subplots(1, 3, figsize=(18, 5))
            im0 = axs[0].imshow(np.real(omega_vals), extent=[-10, 10, -10, 10],
                                origin='lower', cmap='viridis')
            axs[0].set_title(r'$\omega(k_x, k_y)$')
            plt.colorbar(im0, ax=axs[0])

            im1 = axs[1].imshow(v_phase, extent=[-10, 10, -10, 10],
                                origin='lower', cmap='plasma')
            axs[1].set_title(r'$v_p(k_x, k_y)$')
            plt.colorbar(im1, ax=axs[1])

            im2 = axs[2].imshow(v_group_norm, extent=[-10, 10, -10, 10],
                                origin='lower', cmap='inferno')
            axs[2].set_title(r'$|v_g(k_x, k_y)|$')
            plt.colorbar(im2, ax=axs[2])

            for ax in axs:
                ax.set_xlabel(r'$k_x$')
                ax.set_ylabel(r'$k_y$')
                ax.set_aspect('equal')

            plt.tight_layout()
            plt.show()

        else:
            print("❌ Only 1D and 2D wave analysis supported.")

Perform a detailed analysis of wave propagation characteristics based on the dispersion relation ω(k).

This method visualizes key wave properties in both 1D and 2D settings:

  • Dispersion relation: ω(k)
  • Phase velocity: v_p(k) = ω(k)/k in 1D, Re ω(k)/|k| in 2D
  • Group velocity: v_g(k) = ∇ₖ ω(k)
  • Anisotropy in 2D (via magnitude of group velocity)

The symbolic dispersion relation 'omega_symbolic' must be defined beforehand. This is typically available only for second-order-in-time equations.

In 1D: Plots ω(k), v_p(k), and v_g(k) for k in [-10, 10].

In 2D: Displays heatmaps of ω(kx, ky), v_p(kx, ky), and |v_g(kx, ky)| over a 2D wavenumber grid.

Notes: No exception is raised if 'omega_symbolic' is not defined or if the dimension is not 1D or 2D: the method prints a message and returns None.

Side Effects: Generates and displays matplotlib plots.
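The quantities plotted by this method can be checked against a case with a known answer. The sketch below assumes the Klein-Gordon-type dispersion relation ω(k) = √(k² + m²), chosen for illustration and not taken from psipy. It compares the finite-difference group velocity, computed with np.gradient as in the 1D branch, against the exact value k/ω.

```python
import numpy as np

m = 1.0
k = np.linspace(-10, 10, 1000)
omega = np.sqrt(k**2 + m**2)          # ω(k) for u_tt = u_xx - m² u

# Phase velocity ω/k (this grid has an even number of points, so it skips k = 0)
v_phase = omega / k

# Group velocity by finite differences, as in the 1D branch above
v_group = np.gradient(omega, k[1] - k[0])
v_group_exact = k / omega

# For this relation v_p * v_g = 1 (in units where the wave speed is 1)
print(np.max(np.abs(v_group - v_group_exact)))
print(np.max(np.abs(v_phase * v_group_exact - 1)))
```

The finite-difference error is largest near k = 0, where ω(k) curves most strongly, and shrinks quadratically as the grid is refined.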

def plot_symbol(self, component='abs', k_range=None, cmap='viridis'):
    def plot_symbol(self, component="abs", k_range=None, cmap="viridis"):
        """
        Visualize the spectral symbol L(k) or L(kx, ky) in 1D or 2D.

        This method plots the linear operator's symbolic Fourier representation
        either as a function of a single wavenumber k (1D), or two wavenumbers
        kx and ky (2D). The user can choose to display the real part, imaginary part,
        or absolute value of the symbol.

        Parameters
        ----------
        component : str {'abs', 're', 'im'}
            Component of the symbol to visualize:

                - 'abs' : absolute value |a(k)|
                - 're'  : real part Re[a(k)]
                - 'im'  : imaginary part Im[a(k)]

        k_range : tuple (kmin, kmax, N), optional
            Wavenumber range for evaluation:

                - kmin: minimum wavenumber
                - kmax: maximum wavenumber
                - N: number of sampling points

            If None, defaults to [-10, 10] with 1000 points in 1D or 300 points per axis in 2D.
        cmap : str, optional
            Colormap used for 2D surface plots. Default is 'viridis'.

        Raises
        ------
            ValueError: If component is not 'abs', 're' or 'im', or if the spatial
            dimension is not 1D or 2D.

        Notes:
            - In 1D, the symbol is plotted using a standard 2D line plot.
            - In 2D, a 3D surface plot is generated with color-mapped height.
            - Symbol evaluation uses self.L(k), which must be defined and callable.
        """
        print("\n*******************")
        print("* Symbol plotting *")
        print("*******************\n")

        # Explicit check instead of assert, which is skipped under python -O
        if component not in ("abs", "re", "im"):
            raise ValueError("component must be 'abs', 're' or 'im'")

        if self.dim == 1:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 1000)
            else:
                kmin, kmax, N = k_range
                k_vals = np.linspace(kmin, kmax, N)
            L_vals = self.L(k_vals)

            if component == "re":
                vals = np.real(L_vals)
                label = "Re[a(k)]"
            elif component == "im":
                vals = np.imag(L_vals)
                label = "Im[a(k)]"
            else:
                vals = np.abs(L_vals)
                label = "|a(k)|"

            plt.figure()  # new figure, so earlier plots are not drawn over
            plt.plot(k_vals, vals)
            plt.xlabel("k")
            plt.ylabel(label)
            plt.title(f"Spectral symbol: {label}")
            plt.grid(True)
            plt.show()

        elif self.dim == 2:
            if k_range is None:
                k_vals = np.linspace(-10, 10, 300)
            else:
                kmin, kmax, N = k_range
                k_vals = np.linspace(kmin, kmax, N)

            KX, KY = np.meshgrid(k_vals, k_vals)
            L_vals = self.L(KX, KY)

            if component == "re":
                Z = np.real(L_vals)
                title = "Re[a(kx, ky)]"
            elif component == "im":
                Z = np.imag(L_vals)
                title = "Im[a(kx, ky)]"
            else:
                Z = np.abs(L_vals)
                title = "|a(kx, ky)|"

            fig = plt.figure(figsize=(8, 6))
            ax = fig.add_subplot(111, projection='3d')

            surf = ax.plot_surface(KX, KY, Z, cmap=cmap, edgecolor='none', antialiased=True)
            fig.colorbar(surf, ax=ax, shrink=0.6)

            ax.set_xlabel("kx")
            ax.set_ylabel("ky")
            ax.set_zlabel(title)
            ax.set_title(f"2D spectral symbol: {title}")
            plt.tight_layout()
            plt.show()

        else:
            raise ValueError("Only 1D and 2D supported.")

Visualize the spectral symbol L(k) or L(kx, ky) in 1D or 2D.

This method plots the linear operator's symbolic Fourier representation either as a function of a single wavenumber k (1D), or two wavenumbers kx and ky (2D). The user can choose to display the real part, imaginary part, or absolute value of the symbol.

Parameters

component : str {'abs', 're', 'im'}
    Component of the symbol to visualize:

    - 'abs' : absolute value |a(k)|
    - 're'  : real part Re[a(k)]
    - 'im'  : imaginary part Im[a(k)]

k_range : tuple (kmin, kmax, N), optional
    Wavenumber range for evaluation:

    - kmin: minimum wavenumber
    - kmax: maximum wavenumber
    - N: number of sampling points

    If None, defaults to [-10, 10] with 1000 points in 1D or 300 points per axis in 2D.

cmap : str, optional
    Colormap used for 2D surface plots. Default is 'viridis'.

Raises

ValueError: If the spatial dimension is not 1D or 2D.

Notes:

  • In 1D, the symbol is plotted using a standard 2D line plot.
  • In 2D, a 3D surface plot is generated with color-mapped height.
  • Symbol evaluation uses self.L(k), which must be defined and callable.
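The options of the `component` argument are plain NumPy reductions of the complex symbol values. The sketch below uses the illustrative symbol a(kx, ky) = -(kx² + ky²) + i·kx, which is not taken from psipy. It shows how 're', 'im' and 'abs' map to array operations. It also shows how np.meshgrid lays out the grid used in 2D: with the default 'xy' indexing, kx varies along axis 1 and ky along axis 0.

```python
import numpy as np

k_vals = np.linspace(-10, 10, 300)
KX, KY = np.meshgrid(k_vals, k_vals)   # default indexing='xy'
a = -(KX**2 + KY**2) + 1j * KX         # illustrative symbol: diffusion plus drift

# The three options of the `component` argument as NumPy operations
components = {"re": np.real, "im": np.imag, "abs": np.abs}
Z = {name: fn(a) for name, fn in components.items()}

# With 'xy' indexing, kx changes along columns (axis 1) and ky along rows (axis 0)
print(np.array_equal(KX[0, :], k_vals), np.array_equal(KY[:, 0], k_vals))  # True True
print({name: arr.shape for name, arr in Z.items()})
```

The axis layout matters whenever derivatives are taken with np.gradient on such a grid: differentiating along axis 0 gives the ky-derivative, not the kx-derivative.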

def compute_energy(self):
    def compute_energy(self):
        """
        Compute the total energy of the solution for second-order-in-time PDEs.
        The energy is defined as:
            E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx
        where L is the linear operator associated with the spatial part of the PDE,
        and L¹ᐟ² is applied in Fourier space as the multiplier √|L(k)|.

        This method supports both 1D and 2D problems and is only meaningful when
        self.temporal_order == 2 (second-order time derivative).

        Returns
        -------
        float or None
            Total energy at the current time step. Returns None if the temporal order
            is not 2 or if no velocity data (v_prev) is available.

        Raises
        ------
        ValueError
            If the spatial dimension is not 1D or 2D.

        Notes
        -----
        - Applies √|L(k)| via FFT and approximates the integral by a Riemann sum
          over the grid (dx = Lx/Nx, dy = Ly/Ny).
        - Assumes periodic boundary conditions.
        - Handles both real and complex-valued solutions.
        """
        # getattr avoids an AttributeError if v_prev was never initialized
        if self.temporal_order != 2 or getattr(self, 'v_prev', None) is None:
            return None

        u = self.u_prev
        v = self.v_prev

        # Fourier transform of u
        u_hat = self.fft(u)

        if self.dim == 1:
            # 1D case
            L_vals = self.L(self.KX)
            sqrt_L = np.sqrt(np.abs(L_vals))
            Lu_hat = sqrt_L * u_hat  # Apply sqrt(|L(k)|) in Fourier space
            Lu = self.ifft(Lu_hat)

            dx = self.Lx / self.Nx
            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
            total_energy = np.sum(energy_density) * dx

        elif self.dim == 2:
            # 2D case
            L_vals = self.L(self.KX, self.KY)
            sqrt_L = np.sqrt(np.abs(L_vals))
            Lu_hat = sqrt_L * u_hat
            Lu = self.ifft(Lu_hat)

            dx = self.Lx / self.Nx
            dy = self.Ly / self.Ny
            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
            total_energy = np.sum(energy_density) * dx * dy

        else:
            raise ValueError("Unsupported dimension for u.")

        return total_energy

Compute the total energy of the solution for second-order-in-time PDEs. The energy is defined as:

E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx

where L is the linear operator associated with the spatial part of the PDE, and L¹ᐟ² is applied in Fourier space as the multiplier √|L(k)|.

This method supports both 1D and 2D problems and is only meaningful when self.temporal_order == 2 (second-order time derivative).

Returns

float or None
    Total energy at the current time step. Returns None if the temporal order is not 2 or if no velocity data (v_prev) is available.

Notes

  • Applies √|L(k)| via FFT and approximates the integral by a Riemann sum over the grid (dx = Lx/Nx, dy = Ly/Ny).
  • Assumes periodic boundary conditions.
  • Handles both real and complex-valued solutions.
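The energy formula can be checked on the 1D wave equation u_tt = c² u_xx on [0, 2π), where L(k) = -c²k² and √|L(k)| = c|k|. The helper `wave_energy` is hypothetical. It follows the same steps as the 1D branch of compute_energy: FFT, multiply by √|L(k)|, inverse FFT, then a Riemann sum with dx = Lx/Nx. For the exact solution u = sin(x) cos(ct), the energy is ½ c² ∫ sin² x dx = c²π/2 at every time.

```python
import numpy as np


def wave_energy(u, v, L, Lx):
    """E = 1/2 ∫ (|v|² + |√|L| u|²) dx on a periodic 1D grid (hypothetical helper)."""
    Nx = u.size
    dx = Lx / Nx
    k = 2 * np.pi * np.fft.fftfreq(Nx, d=dx)                  # discrete wavenumbers
    sqrt_L_u = np.fft.ifft(np.sqrt(np.abs(L(k))) * np.fft.fft(u))
    density = 0.5 * (np.abs(v) ** 2 + np.abs(sqrt_L_u) ** 2)
    return np.sum(density) * dx                              # Riemann sum, periodic grid


c = 1.5
Lx, Nx = 2 * np.pi, 128
x = np.arange(Nx) * Lx / Nx


def symbol(k):
    return -(c * k) ** 2          # u_tt = c² u_xx  ->  L(k) = -c² k²


# u = sin(x) cos(ct), u_t = -c sin(x) sin(ct): the energy stays at c² π / 2
for t in (0.0, 0.7, 2.3):
    u = np.sin(x) * np.cos(c * t)
    v = -c * np.sin(x) * np.sin(c * t)
    print(t, wave_energy(u, v, symbol, Lx), c**2 * np.pi / 2)
```

Energy exchanges between the kinetic term |∂ₜu|² and the potential term |L¹ᐟ²u|² over time while their sum stays constant. A drift in the recorded values therefore points to time-stepping error rather than physics.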

def plot_energy(self, log=False):
    def plot_energy(self, log=False):
        """
        Plot the time evolution of the total energy for wave equations.
        Visualizes the energy computed during simulation for both 1D and 2D cases.
        Requires temporal_order=2 and prior execution of compute_energy() during solve().

        Parameters:
            log : bool, default False
                If True, displays energy on a logarithmic scale to highlight exponential decay/growth.

        Notes:
            - Energy is defined as E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx
            - Only available if energy monitoring was activated in solve()
            - Automatically skips plotting (with a message) if no energy data is available
            - The time axis assumes the recorded values are evenly spaced on [0, Lt]

        Displays:
            - Time vs. Total Energy plot with grid and legend
            - Appropriate axis labels and dimensional context (1D/2D)
            - Logarithmic or linear scaling based on input parameter
        """
        if not hasattr(self, 'energy_history') or not self.energy_history:
            print("No energy data recorded. Call compute_energy() within solve().")
            return

        # Time vector for plotting
        t = np.linspace(0, self.Lt, len(self.energy_history))

        # Create the figure
        plt.figure(figsize=(6, 4))
        if log:
            plt.semilogy(t, self.energy_history, label="Energy (log scale)")
        else:
            plt.plot(t, self.energy_history, label="Energy")

        # Axis labels and title
        plt.xlabel("Time")
        plt.ylabel("Total energy")
        plt.title("Energy evolution ({}D)".format(self.dim))

        # Display options
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

Plot the time evolution of the total energy for wave equations. Visualizes the energy computed during simulation for both 1D and 2D cases. Requires temporal_order=2 and prior execution of compute_energy() during solve().

Parameters:

log : bool, default False
    If True, displays energy on a logarithmic scale to highlight exponential decay/growth.

Notes:

  • Energy is defined as E(t) = 1/2 ∫ [ |∂ₜu|² + |L¹ᐟ²u|² ] dx
  • Only available if energy monitoring was activated in solve()
  • Automatically skips plotting (with a message) if no energy data is available
  • The time axis assumes the recorded values are evenly spaced on [0, Lt]

Displays:

  • Time vs. Total Energy plot with grid and legend
  • Appropriate axis labels and dimensional context (1D/2D)
  • Logarithmic or linear scaling based on input parameter
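On a logarithmic scale, an energy decaying as E(t) = E₀ e^{-2γt} (amplitudes decaying like e^{-γt}) appears as a straight line of slope -2γ. The sketch below recovers γ with a least-squares fit of log E. It assumes a synthetic `energy_history`, not produced by psipy, and the same evenly spaced time axis on [0, Lt] that plot_energy constructs.

```python
import numpy as np

Lt, gamma, E0 = 5.0, 0.3, 2.0
# Synthetic record standing in for self.energy_history (illustration only)
energy_history = list(E0 * np.exp(-2 * gamma * np.linspace(0, Lt, 200)))

# Same time axis that plot_energy builds: evenly spaced on [0, Lt]
t = np.linspace(0, Lt, len(energy_history))
slope, intercept = np.polyfit(t, np.log(energy_history), 1)
gamma_est = -slope / 2           # E ∝ exp(-2γt)  ->  log E has slope -2γ

# Relative drift from the initial value; close to 0 for a conservative run
rel_drift = np.max(np.abs(np.asarray(energy_history) / energy_history[0] - 1))
print(gamma_est, np.exp(intercept), rel_drift)
```

For an undamped wave equation, the fitted slope should be near zero, and rel_drift measures how well the time integrator conserves energy.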

def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
    def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
        """
        Display the stationary solution computed by solve_stationary_psiOp.

        This method visualizes the solution of a pseudo-differential equation
        solved in stationary mode. It supports both 1D and 2D spatial domains,
        with options to display different components of the solution (real,
        imaginary, absolute value, or phase).

        Parameters
        ----------
        u : ndarray, optional
            Precomputed solution array. If None, calls solve_stationary_psiOp()
            to compute the solution.
        component : str, optional {'real', 'imag', 'abs', 'angle'}
            Component of the complex-valued solution to display:
            - 'real': Real part
            - 'imag': Imaginary part
            - 'abs' : Absolute value (modulus)
            - 'angle' : Phase (argument)
        cmap : str, optional
            Colormap used for 2D visualization (default: 'viridis').

        Raises
        ------
        ValueError
            If an invalid component is specified or if the spatial dimension
            is not supported (only 1D and 2D are implemented).

        Notes
        -----
        - In 1D, the solution is displayed using a standard line plot.
        - In 2D, the solution is visualized as a 3D surface plot.
        """
        components = {'real': np.real, 'imag': np.imag, 'abs': np.abs, 'angle': np.angle}
        # Validate before the (potentially expensive) stationary solve
        if component not in components:
            raise ValueError("Invalid component: choose 'real', 'imag', 'abs' or 'angle'")
        get_component = components[component]

        if u is None:
            u = self.solve_stationary_psiOp()

        if self.dim == 1:
            # Plot the solution in 1D
            plt.figure(figsize=(8, 4))
            plt.plot(self.x_grid, get_component(u), label=f'{component} of u')
            plt.xlabel('x')
            plt.ylabel(f'{component} of u')
            plt.title('Stationary solution (1D)')
            plt.grid(True)
            plt.legend()
            plt.tight_layout()
            plt.show()

        elif self.dim == 2:
            fig = plt.figure(figsize=(12, 6))
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel(f'{component.title()} of u')
            ax.set_title('Stationary solution (2D)')
            data0 = get_component(u)
            ax.plot_surface(self.X, self.Y, data0, cmap=cmap)  # honor the cmap argument
            plt.tight_layout()
            plt.show()

        else:
            raise ValueError("Only 1D and 2D display are supported.")

Display the stationary solution computed by solve_stationary_psiOp.

This method visualizes the solution of a pseudo-differential equation solved in stationary mode. It supports both 1D and 2D spatial domains, with options to display different components of the solution (real, imaginary, absolute value, or phase).

Parameters

u : ndarray, optional
    Precomputed solution array. If None, calls solve_stationary_psiOp() to compute the solution.

component : str, optional {'real', 'imag', 'abs', 'angle'}
    Component of the complex-valued solution to display:

    - 'real': Real part
    - 'imag': Imaginary part
    - 'abs' : Absolute value (modulus)
    - 'angle' : Phase (argument)

cmap : str, optional
    Colormap used for 2D visualization (default: 'viridis').

Raises

ValueError If an invalid component is specified or if the spatial dimension is not supported (only 1D and 2D are implemented).

Notes

  • In 1D, the solution is displayed using a standard line plot.
  • In 2D, the solution is visualized as a 3D surface plot.
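A stationary pseudo-differential problem a(D) u = f with a constant-coefficient symbol can be solved by division in Fourier space. The sketch below is not solve_stationary_psiOp, whose algorithm is not shown here. It assumes the elliptic symbol a(k) = 1 + k², the operator 1 - ∂ₓ², on a periodic grid, and then extracts the same components that show_stationary_solution can display.

```python
import numpy as np

Lx, Nx = 2 * np.pi, 256
x = np.arange(Nx) * Lx / Nx
k = 2 * np.pi * np.fft.fftfreq(Nx, d=Lx / Nx)

# Right-hand side f = 2 cos x; the exact solution of (1 - d²/dx²) u = f is u = cos x
f = 2 * np.cos(x)
symbol = 1 + k**2                 # elliptic: never zero, so the division is safe
u = np.fft.ifft(np.fft.fft(f) / symbol)

# Same component options as show_stationary_solution
components = {"real": np.real, "imag": np.imag, "abs": np.abs, "angle": np.angle}
views = {name: fn(u) for name, fn in components.items()}
print(np.max(np.abs(views["real"] - np.cos(x))), np.max(np.abs(views["imag"])))
```

The 'angle' view is mostly meaningful for genuinely complex solutions. For a real solution such as this one, it only jumps between 0 and ±π where cos x changes sign.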

def animate(self, component='abs', overlay='contour', mode='surface'):
    def animate(self, component='abs', overlay='contour', mode='surface'):
        """
        Create an animated plot of the solution evolution over time.

        This method generates a dynamic visualization of the stored solution frames
        `self.frames`. It supports:
          - 1D line animation,
          - 2D surface animation ('surface'),
          - 2D image animation using imshow ('imshow'), which is faster and
            often clearer for large grids.

        Parameters
        ----------
        component : str, optional, one of {'real', 'imag', 'abs', 'angle'}
            Which component of the complex field to visualize:
              - 'real'  : Re(u)
              - 'imag'  : Im(u)
              - 'abs'   : |u|
              - 'angle' : arg(u)
            Default is 'abs'.

        overlay : str or None, optional, one of {'contour', 'front', None}
            For 2D modes only. If None, no overlay is drawn.
              - 'contour' : draw contour lines (in 'surface' mode they are projected
                            onto a plane slightly above the surface maximum)
              - 'front'   : detect and mark wavefronts at local maxima of |∇u|
            Default is 'contour'.

        mode : str, optional, one of {'surface', 'imshow'}
            2D rendering mode. 'surface' draws a 3D surface plot;
            'imshow' draws a 2D raster (faster, often more readable).
            Default is 'surface'.

        Returns
        -------
        FuncAnimation
            A Matplotlib `FuncAnimation` instance (you can display it in a notebook
            or save it to file).

        Notes
        -----
        - The animation shows n_frames // 2 evenly spaced times on [0, Lt]; each
          is mapped to the nearest stored frame.
        - For 'angle' the color scale is fixed between -π and π.
        - For other components, color scaling is by default dynamically adapted per
          frame in 'imshow' mode (this avoids extreme clipping if amplitudes vary).
        - Overlays are updated cleanly: previous contour/scatter artists are removed
          before drawing the next frame to avoid memory/visual accumulation.
        - The animation interval is 50 ms per frame.
        """
        def get_component(u):
            if component == 'real':
                return np.real(u)
            elif component == 'imag':
                return np.imag(u)
            elif component == 'abs':
                return np.abs(u)
            elif component == 'angle':
                return np.angle(u)
            else:
                raise ValueError("Invalid component: choose 'real','imag','abs' or 'angle'")

        print("\n*********************")
        print("* Solution plotting *")
        print("*********************\n")

        # === Calculate time vector of stored frames ===
        save_interval = max(1, self.Nt // self.n_frames)
        frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)

        # === Target times for animation ===
        target_times = np.linspace(0, self.Lt, self.n_frames // 2)

        # Map target times to the nearest stored frame; clamp in case floating-point
        # round-off in np.arange yields more time stamps than stored frames
        last_frame = len(self.frames) - 1
        frame_indices = [min(int(np.argmin(np.abs(frame_times - t))), last_frame)
                         for t in target_times]

        # -------------------------
        # 1D case
        # -------------------------
        if self.dim == 1:
            fig, ax = plt.subplots()
            initial = get_component(self.frames[0])
            line, = ax.plot(self.X, np.real(initial) if np.iscomplexobj(initial) else initial)
            ax.set_ylim(np.min(initial), np.max(initial))
            ax.set_xlabel('x')
            ax.set_ylabel(f'{component} of u')
            ax.set_title('Initial condition')
            plt.tight_layout()

            def update_1d(frame_number):
                frame = frame_indices[frame_number]
                ydata = get_component(self.frames[frame])
                ydata_real = np.real(ydata) if np.iscomplexobj(ydata) else ydata
                line.set_ydata(ydata_real)
                ax.set_ylim(np.min(ydata_real), np.max(ydata_real))
                current_time = target_times[frame_number]
                ax.set_title(f't = {current_time:.2f}')
                return (line,)

            ani = FuncAnimation(fig, update_1d, frames=len(target_times), interval=50)
            return ani

        # -------------------------
        # 2D case
        # -------------------------
        # Validate mode
        if mode not in ('surface', 'imshow'):
            raise ValueError("Invalid mode: choose 'surface' or 'imshow'")

        # Common data
        data0 = get_component(self.frames[0])

        if mode == 'surface':
            # 3D surface; the axes are cleared and redrawn at every frame
            fig = plt.figure(figsize=(14, 8))
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel(f'{component.title()} of u')
            ax.zaxis.labelpad = 0
            ax.set_title('Initial condition')

            surf = ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
            plt.tight_layout()

            def update_surface(frame_number):
                frame = frame_indices[frame_number]
                current_data = get_component(self.frames[frame])
                # Height of the overlay plane: 5% of the data range above the maximum
                z_offset = np.max(current_data) + 0.05 * (np.max(current_data) - np.min(current_data))

                ax.clear()
                surf_obj = ax.plot_surface(self.X, self.Y, current_data,
                                           cmap='viridis',
                                           vmin=(-np.pi if component == 'angle' else None),
                                           vmax=(np.pi if component == 'angle' else None))
                # overlays
                if overlay == 'contour':
                    # project contours onto a plane slightly above the surface (offset)
                    try:
                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool', offset=z_offset)
                    except Exception:
                        # fallback: plain 3D contour without projection
                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')

                elif overlay == 'front':
                    dx = self.x_grid[1] - self.x_grid[0]
                    dy = self.y_grid[1] - self.y_grid[0]
                    # numpy.gradient: axis 0 -> y spacing, axis 1 -> x spacing
                    du_dy, du_dx = np.gradient(current_data, dy, dx)
                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
                    # wavefront candidates: local maxima of |∇u| in a 5x5 neighborhood
                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
                    if np.max(grad_norm) > 0:
                        normalized = grad_norm[local_max] / np.max(grad_norm)
                    else:
                        normalized = np.zeros(np.count_nonzero(local_max))
                    colors = cm.plasma(normalized)
                    ax.scatter(self.X[local_max], self.Y[local_max],
                               z_offset * np.ones_like(self.X[local_max]),
                               color=colors, s=10, alpha=0.8)

                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel(f'{component.title()} of u')
                current_time = target_times[frame_number]
                ax.set_title(f'Solution at t = {current_time:.2f}')
                return (surf_obj,)

            ani = FuncAnimation(fig, update_surface, frames=len(target_times), interval=50)
            return ani

        else:  # mode == 'imshow'
            fig, ax = plt.subplots(figsize=(7, 6))
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_title('Initial condition')

            # extent uses physical coordinates so axes show real x/y values
            extent = [self.x_grid[0], self.x_grid[-1], self.y_grid[0], self.y_grid[-1]]
2615    
2616            if component == 'angle':
2617                vmin, vmax = -np.pi, np.pi
2618                cmap = 'twilight'
2619            else:
2620                vmin, vmax = np.min(data0), np.max(data0)
2621                cmap = 'viridis'
2622    
2623            im = ax.imshow(data0, extent=extent, origin='lower', cmap=cmap,
2624                           vmin=vmin, vmax=vmax, aspect='auto')
2625            cbar = fig.colorbar(im, ax=ax)
2626            cbar.set_label(f"{component} of u")
2627            plt.tight_layout()
2628    
2629            # containers for dynamic overlay artists (stored on function object)
2630            # update_im.contour_art and update_im.scatter_art will be created dynamically
2631    
2632            def update_im(frame_number):
2633                frame = frame_indices[frame_number]
2634                current_data = get_component(self.frames[frame])
2635    
2636                # update raster
2637                im.set_data(current_data)
2638                if component != 'angle':
2639                    # dynamic per-frame scaling (keeps contrast when amplitude varies)
2640                    cmin = np.nanmin(current_data)
2641                    cmax = np.nanmax(current_data)
2642                    # avoid identical vmin==vmax
2643                    if cmax > cmin:
2644                        im.set_clim(cmin, cmax)
2645    
2646                # remove previous contour if exists
2647                if overlay == 'contour':
2648                    if hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2649                        for coll in update_im.contour_art.collections:
2650                            try:
2651                                coll.remove()
2652                            except Exception:
2653                                pass
2654                        update_im.contour_art = None
2655                    # draw new contours (use meshgrid coords)
2656                    try:
2657                        update_im.contour_art = ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
2658                    except Exception:
2659                        # fallback: contour with axis coordinates (x_grid, y_grid)
2660                        Xc, Yc = np.meshgrid(self.x_grid, self.y_grid)
2661                        update_im.contour_art = ax.contour(Xc, Yc, current_data, levels=10, cmap='cool')
2662    
2663                # remove previous scatter if exists
2664                if overlay == 'front':
2665                    if hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2666                        try:
2667                            update_im.scatter_art.remove()
2668                        except Exception:
2669                            pass
2670                        update_im.scatter_art = None
2671    
2672                    dx = self.x_grid[1] - self.x_grid[0]
2673                    dy = self.y_grid[1] - self.y_grid[0]
2674                    du_dy, du_dx = np.gradient(current_data, dy, dx)
2675                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
2676                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
2677                    if np.max(grad_norm) > 0:
2678                        normalized = grad_norm[local_max] / np.max(grad_norm)
2679                    else:
2680                        normalized = np.zeros(np.count_nonzero(local_max))
2681                    colors = cm.plasma(normalized)
2682                    update_im.scatter_art = ax.scatter(self.X[local_max], self.Y[local_max],
2683                                                       c=colors, s=10, alpha=0.8)
2684    
2685                current_time = target_times[frame_number]
2686                ax.set_title(f'Solution at t = {current_time:.2f}')
2687                # return main image plus any overlay artists present so Matplotlib can redraw them
2688                artists = [im]
2689                if overlay == 'contour' and hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2690                    artists.extend(update_im.contour_art.collections)
2691                if overlay == 'front' and hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2692                    artists.append(update_im.scatter_art)
2693                return tuple(artists)
2694    
2695            ani = FuncAnimation(fig, update_im, frames=len(target_times), interval=50)
2696            return ani

Create an animated plot of the solution evolution over time.

This method animates the solution frames stored in self.frames. It supports:

  • 1D line animation,
  • 2D surface animation (mode='surface'),
  • 2D image animation with imshow (mode='imshow'), which is faster and often clearer on large grids.
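The speed advantage of 'imshow' comes from reusing a single AxesImage: each frame only calls set_data and, for components other than 'angle', set_clim, instead of clearing and redrawing a 3D surface. A minimal sketch of that pattern, assuming synthetic Gaussian frames in place of self.frames (make_imshow_animation is a hypothetical name, not part of psipy):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def make_imshow_animation(frames, x, y, interval=50):
    """Animate |u| for a sequence of 2D complex frames with a single AxesImage."""
    fig, ax = plt.subplots(figsize=(5, 4))
    data0 = np.abs(frames[0])
    im = ax.imshow(data0, extent=[x[0], x[-1], y[0], y[-1]], origin="lower",
                   cmap="viridis", vmin=data0.min(), vmax=data0.max(), aspect="auto")
    fig.colorbar(im, ax=ax)

    def update(k):
        data = np.abs(frames[k])
        im.set_data(data)                    # reuse the raster, no new artist
        if data.max() > data.min():          # avoid a degenerate color range
            im.set_clim(data.min(), data.max())
        ax.set_title(f"frame {k}")
        return (im,)

    ani = FuncAnimation(fig, update, frames=len(frames), interval=interval)
    return fig, ani, update, im


# Illustrative data: a Gaussian whose amplitude decays from frame to frame
x = np.linspace(-5, 5, 64)
y = np.linspace(-5, 5, 48)
X, Y = np.meshgrid(x, y)
frames = [np.exp(-0.5 * t) * np.exp(-(X**2 + Y**2)) for t in range(10)]
fig, ani, update, im = make_imshow_animation(frames, x, y)
update(5)
print(im.get_clim())
```

Calling update directly, as in the last lines, checks a single frame without rendering the whole animation; displaying the animation in a notebook or saving it with ani.save(...) renders every frame.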

Parameters

component : str, optional, one of {'real', 'imag', 'abs', 'angle'}
Which component of the complex field to visualize. Default is 'abs'.

  • 'real' : Re(u)
  • 'imag' : Im(u)
  • 'abs' : |u|
  • 'angle' : arg(u)

overlay : str or None, optional, one of {'contour', 'front', None}
For 2D modes only; if None, no overlay is drawn. Default is 'contour'.

  • 'contour' : draw 10 contour levels over the image ('imshow') or on a plane just above the surface ('surface')
  • 'front' : mark wavefronts as local maxima of the gradient norm |∇u|

mode : str, optional, one of {'surface', 'imshow'}
2D rendering mode: 'surface' draws a 3D surface plot, 'imshow' draws a 2D raster (faster and often more readable). Default is 'surface' for backward compatibility.

Returns

FuncAnimation
A Matplotlib FuncAnimation instance, which can be displayed in a notebook or saved to a file.

Notes

  • Animation frames are obtained by linear sampling of the stored frames in time.
  • For 'angle' the color scale is fixed between -π and π.
  • For other components in 'imshow' mode, the color limits are recomputed from each frame's minimum and maximum, which preserves contrast when the amplitude varies.
  • Overlays are updated cleanly: previous contour/scatter artists are removed before drawing the next frame to avoid memory/visual accumulation.
  • The animation interval is 50 ms per frame.
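The 'front' overlay keeps the grid points where |∇u| equals the maximum of its own 5×5 neighbourhood (scipy.ndimage.maximum_filter with size=5) and colors them by |∇u| / max|∇u|. The sketch below isolates that detection step from the plotting code; detect_front is an illustrative helper, not part of psipy, and the radial tanh profile is made-up test data.

```python
import numpy as np
from scipy.ndimage import maximum_filter


def detect_front(data, x, y, size=5):
    """Return (mask, weights): local maxima of |grad u| and their normalized strength."""
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    # data has shape (len(y), len(x)): axis 0 is y, axis 1 is x
    du_dy, du_dx = np.gradient(data, dy, dx)
    grad_norm = np.hypot(du_dx, du_dy)
    # keep a point if it equals the maximum of its size x size neighbourhood
    mask = grad_norm == maximum_filter(grad_norm, size=size)
    gmax = grad_norm.max()
    weights = grad_norm[mask] / gmax if gmax > 0 else np.zeros(np.count_nonzero(mask))
    return mask, weights


# Illustrative data: a smooth radial front located at R = 2
x = np.linspace(-4, 4, 161)
y = np.linspace(-4, 4, 161)
X, Y = np.meshgrid(x, y)
R = np.hypot(X, Y)
u = np.tanh(5 * (2.0 - R))

mask, w = detect_front(u, x, y)
strong = mask.copy()
strong[mask] = w > 0.5  # keep only points with at least half the peak gradient
print(np.abs(R[strong] - 2.0).max())
```

In this example the points with at least half the peak gradient cluster around the circle R = 2. One caveat of the equality test: every point of a region where |∇u| is exactly constant (for example identically zero) also counts as a local maximum, so a relative threshold like the w > 0.5 filter above can help suppress such points.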
def test(self, u_exact, t_eval=None, norm='relative', threshold=0.01, component='real'):
    def test(self, u_exact, t_eval=None, norm='relative', threshold=1e-2, component='real'):
        """
        Test the solver against an exact solution.

        This method quantitatively compares the numerical solution with a provided exact solution
        at a specified time using either relative or absolute error norms. It supports both
        stationary and time-dependent problems in 1D and 2D. If self.plot is True, it also plots
        the numerical solution, the exact solution, and the pointwise error.

        Parameters
        ----------
        u_exact : callable
            Exact solution, called as u_exact(X) or u_exact(X, Y) for stationary problems
            and as u_exact(X, t) or u_exact(X, Y, t) for time-dependent ones.
        t_eval : float, optional
            Time at which to compare solutions. For non-stationary problems, defaults to final time Lt.
            Ignored for stationary problems.
        norm : str {'relative', 'absolute'}
            Type of error norm used in comparison.
        threshold : float
            Acceptable error threshold; an AssertionError is raised if it is exceeded.
        component : str {'real', 'imag', 'abs'}
            Component of the solution to compare and visualize.

        Raises
        ------
        ValueError
            If the dimension is unsupported, the requested evaluation time exceeds the
            simulation duration, norm or component is not recognized, or a relative norm
            is requested for an identically zero exact component.
        AssertionError
            If computed error exceeds the given threshold.

        Prints
        ------
        - Information about the closest available frame to the requested evaluation time.
        - Computed error value and comparison to threshold.

        Notes
        -----
        - Plotting is controlled by the solver attribute self.plot, not by an argument.
        - For time-dependent problems, the solution is extracted from precomputed frames.
        - Plots are adapted to spatial dimension: line plots for 1D, image plots for 2D.
        - The selected component is used for both the error and the plots, and the plots
          are drawn before the threshold check so that a failing test is still visualized.
        """
        if self.is_stationary:
            print("Testing a stationary solution.")
            u_num = self.u

            # Compute exact solution
            if self.dim == 1:
                u_ex = u_exact(self.X)
            elif self.dim == 2:
                u_ex = u_exact(self.X, self.Y)
            else:
                raise ValueError("Unsupported dimension.")
            actual_t = None
        else:
            if t_eval is None:
                t_eval = self.Lt
            if t_eval > self.Lt:
                raise ValueError(f"Time t = {t_eval} exceeds simulation duration Lt = {self.Lt}.")

            # Times of the stored frames: one frame every save_interval time steps
            save_interval = max(1, self.Nt // self.n_frames)
            frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
            frame_index = np.argmin(np.abs(frame_times - t_eval))
            actual_t = frame_times[frame_index]
            print(f"Closest available time to t_eval={t_eval}: {actual_t}")

            if frame_index >= len(self.frames):
                raise ValueError(f"Time t = {t_eval} exceeds simulation duration.")

            u_num = self.frames[frame_index]

            # Compute exact solution at the actual time
            if self.dim == 1:
                u_ex = u_exact(self.X, actual_t)
            elif self.dim == 2:
                u_ex = u_exact(self.X, self.Y, actual_t)
            else:
                raise ValueError("Unsupported dimension.")

        # Select component (used for both the error and the plots)
        if component == 'real':
            select = np.real
        elif component == 'imag':
            select = np.imag
        elif component == 'abs':
            select = np.abs
        else:
            raise ValueError("Invalid component.")
        num_c = select(u_num)
        ex_c = select(u_ex)
        diff = num_c - ex_c

        # Compute error
        if norm == 'relative':
            ref_norm = np.linalg.norm(ex_c)
            if ref_norm == 0:
                raise ValueError("Relative error undefined: the exact component is zero; use norm='absolute'.")
            error = np.linalg.norm(diff) / ref_norm
        elif norm == 'absolute':
            error = np.linalg.norm(diff)
        else:
            raise ValueError("Unknown norm type.")

        label_time = f"t = {actual_t}" if actual_t is not None else ""
        print(f"Test error {label_time}: {error:.3e} (threshold {threshold:.1e})")

        # Plot before asserting, so that a failing test still shows the comparison
        if self.plot:
            if self.dim == 1:
                plt.figure(figsize=(12, 6))
                plt.subplot(2, 1, 1)
                plt.plot(self.X, num_c, label='Numerical')
                plt.plot(self.X, ex_c, '--', label='Exact')
                plt.title(f'Solution ({component}) {label_time}, error = {error:.2e}')
                plt.legend()
                plt.grid()

                plt.subplot(2, 1, 2)
                plt.plot(self.X, np.abs(diff), color='red')
                plt.title('Absolute Error')
                plt.grid()
                plt.tight_layout()
                plt.show()
            else:
                extent = [-self.Lx/2, self.Lx/2, -self.Ly/2, self.Ly/2]
                plt.figure(figsize=(15, 5))
                plt.subplot(1, 3, 1)
                plt.title(f"Numerical Solution ({component})")
                plt.imshow(num_c, origin='lower', extent=extent, cmap='viridis')
                plt.colorbar()

                plt.subplot(1, 3, 2)
                plt.title(f"Exact Solution ({component})")
                plt.imshow(ex_c, origin='lower', extent=extent, cmap='viridis')
                plt.colorbar()

                plt.subplot(1, 3, 3)
                plt.title(f"Error (Norm = {error:.2e})")
                plt.imshow(np.abs(diff), origin='lower', extent=extent, cmap='inferno')
                plt.colorbar()
                plt.tight_layout()
                plt.show()

        assert error < threshold, f"Error too large {label_time}: {error:.3e}"

Test the solver against an exact solution.

This method quantitatively compares the numerical solution with a provided exact solution at a specified time using either a relative or an absolute error norm. It supports both stationary and time-dependent problems in 1D and 2D. When self.plot is True, it also plots the numerical solution, the exact solution, and the pointwise error.

Parameters

u_exact : callable
Exact solution, called as u_exact(X) or u_exact(X, Y) for stationary problems and as u_exact(X, t) or u_exact(X, Y, t) for time-dependent ones.

t_eval : float, optional
Time at which to compare solutions. For non-stationary problems, defaults to the final time Lt. Ignored for stationary problems.

norm : str, one of {'relative', 'absolute'}
Type of error norm used in the comparison.

threshold : float
Acceptable error threshold; an AssertionError is raised if it is exceeded.

component : str, one of {'real', 'imag', 'abs'}
Component of the solution to compare and visualize.

Plotting is controlled by the solver attribute self.plot rather than by an argument of this method.
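For a time-dependent 1D problem, u_exact must accept the grid array plus a time argument. A minimal sketch, assuming the heat equation u_t = u_xx with u(x, 0) = sin(x) purely as an illustration (it is not taken from psipy's own tests), whose exact solution is u(x, t) = e^{-t} sin(x); the finite-difference check only confirms that the callable satisfies the PDE:

```python
import numpy as np


def u_exact(X, t):
    """Exact solution of u_t = u_xx with u(x, 0) = sin(x) (illustrative problem)."""
    return np.exp(-t) * np.sin(X)


# Check the callable against the PDE with centred finite differences
X = np.linspace(-np.pi, np.pi, 401)
h, dt, t = X[1] - X[0], 1e-4, 0.5
u_t = (u_exact(X, t + dt) - u_exact(X, t - dt)) / (2 * dt)
u_xx = (u_exact(X[2:], t) - 2 * u_exact(X[1:-1], t) + u_exact(X[:-2], t)) / h**2
residual = np.max(np.abs(u_t[1:-1] - u_xx))
print(f"max |u_t - u_xx| = {residual:.2e}")
```

Such a function can be passed directly as u_exact; for time-dependent problems, test() evaluates it at the time of the closest stored frame rather than at t_eval itself.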

Raises

ValueError
If the dimension is unsupported, the requested evaluation time exceeds the simulation duration, or norm or component is not recognized.

AssertionError
If the computed error exceeds the given threshold.

Prints

  • Information about the closest available frame to the requested evaluation time.
  • Computed error value and comparison to threshold.
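The "closest available time" message comes from rebuilding the grid of stored frame times and picking the nearest one. A sketch of that selection, with illustrative values for Lt, dt, Nt and n_frames (closest_frame is a hypothetical helper that mirrors the source listing above, not a psipy function):

```python
import numpy as np


def closest_frame(t_eval, Lt, dt, Nt, n_frames):
    """Return (index, time) of the stored frame closest to t_eval."""
    save_interval = max(1, Nt // n_frames)  # time steps between saved frames
    frame_times = np.arange(0, Lt + dt, save_interval * dt)
    idx = int(np.argmin(np.abs(frame_times - t_eval)))
    return idx, float(frame_times[idx])


# Illustrative run: Lt = 1.0, dt = 0.01 (Nt = 100 steps), 10 saved frames
idx, t_actual = closest_frame(0.437, Lt=1.0, dt=0.01, Nt=100, n_frames=10)
print(idx, t_actual)
```

Because frame_times is recomputed from Nt, n_frames and dt rather than read from the solver, it matches self.frames only if the frames were saved on exactly that schedule; the frame_index >= len(self.frames) check guards against a mismatch.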

Notes

  • For time-dependent problems, the solution is extracted from precomputed frames.
  • Plots are adapted to spatial dimension: line plots for 1D, image plots for 2D.
  • The method ensures consistent handling of real, imaginary, and magnitude components.
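The error itself is measured by extracting the chosen component from both arrays and applying np.linalg.norm over the whole grid (no grid-spacing weight), optionally divided by the norm of the exact component. A self-contained sketch of that computation; component_error is a hypothetical name and the phase-shifted plane wave is illustrative data:

```python
import numpy as np


def component_error(u_num, u_ex, component="real", norm="relative"):
    """Discrete 2-norm error between two arrays, restricted to one component."""
    select = {"real": np.real, "imag": np.imag, "abs": np.abs}[component]
    diff = select(u_num) - select(u_ex)
    if norm == "absolute":
        return np.linalg.norm(diff)
    if norm == "relative":
        return np.linalg.norm(diff) / np.linalg.norm(select(u_ex))
    raise ValueError("Unknown norm type.")


x = np.linspace(0, 2 * np.pi, 256, endpoint=False)
u_ex = np.exp(1j * x)                  # exact plane wave
u_num = u_ex * np.exp(1j * 1e-3)       # same wave with a small phase error
print(component_error(u_num, u_ex, "real"))  # sensitive to the phase error
print(component_error(u_num, u_ex, "abs"))   # blind to it
```

The example shows why the component matters: a small phase error is visible in the real part but invisible in |u|. The relative norm is also undefined when the exact component is identically zero, for example component='imag' for a real-valued solution.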
class LagrangianHamiltonianConverter:
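A standalone sketch of the two main routes this class offers, assuming only SymPy and NumPy (quadratic_legendre and fenchel_conjugate_grid are illustrative names, not the class's API). For L = ½ pᵀA p + bᵀp + c with constant A, solving ξ = ∇ₚL gives p = A⁻¹(ξ − b) and H = ξ·p − L; for a nonsmooth L such as |p|, the Legendre–Fenchel conjugate sup_p (ξp − L(p)) is approximated by a maximum over a bounded grid, as in the "grid" mode of the numeric helpers.

```python
import numpy as np
import sympy as sp


def quadratic_legendre(L, p_vars, xi_vars):
    """Legendre transform of a Lagrangian quadratic in p (constant Hessian A)."""
    A = sp.hessian(L, p_vars)                   # A = d2L/dp2, assumed invertible
    b = sp.Matrix([sp.diff(L, p) for p in p_vars]).subs({p: 0 for p in p_vars})
    p_sol = A.inv() * (sp.Matrix(xi_vars) - b)  # solve xi = A p + b for p
    sol = dict(zip(p_vars, p_sol))
    H = sum(xi * sol[p] for xi, p in zip(xi_vars, p_vars)) - L.subs(sol)
    return sp.simplify(H)


def fenchel_conjugate_grid(L_func, xi, p_min=-10.0, p_max=10.0, n=2001):
    """Approximate H(xi) = sup_p (xi*p - L(p)) by a maximum over a 1D grid."""
    p = np.linspace(p_min, p_max, n)
    return float(np.max(xi * p - L_func(p)))


# Symbolic route: 2D particle, L = m|p|^2/2 - V  ->  H = |xi|^2/(2m) + V
m = sp.Symbol("m", positive=True)
V = sp.Symbol("V", real=True)
px, py, xi, eta = sp.symbols("p_x p_y xi eta", real=True)
H = quadratic_legendre(m * (px**2 + py**2) / 2 - V, (px, py), (xi, eta))
print(H)

# Numeric route: nonsmooth L(p) = |p|, where the symbolic transform does not apply
for s in (0.5, 2.0):
    print(s, fenchel_conjugate_grid(np.abs, s))
```

Because p is restricted to [-10, 10], the conjugate of |p| comes out as max(0, 10(|ξ| − 1)) instead of the exact result (0 for |ξ| ≤ 1, +∞ otherwise); the p_bounds entry of fenchel_opts plays the same role in method="fenchel_numeric".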
class LagrangianHamiltonianConverter:
    """
    Bidirectional converter between Lagrangian and Hamiltonian (Legendre transform),
    with optional Legendre–Fenchel (convex conjugate) support and robust numeric fallback.

    Main API:
      L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False,
             method="legendre", fenchel_opts=None)

        - method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
        - If method == "fenchel_numeric" returns (H_repr, xi_vars, numeric_callable)
          otherwise returns (H_expr, xi_vars)
    """

    _numeric_cache = {}

    # --------------------
    # Utilities
    # --------------------
    @staticmethod
    def _is_quadratic_in_p(L_expr, p_vars):
        """
        Robust test: returns True only if L_expr is a polynomial of total degree ≤ 2
        in p_vars, i.e. its Hessian with respect to p is constant.
        Falls back to False for non-polynomial expressions (Abs, sqrt, etc.).
        """
        for p in p_vars:
            # Quick test: is L polynomial in p?
            if not L_expr.is_polynomial(p):
                return False
        try:
            # The total degree is what matters: p_x**2 * p_y**2 has degree 2 in
            # each variable separately but a p-dependent Hessian.
            deg = sp.Poly(L_expr, *p_vars).total_degree()
        except Exception:
            return False
        return deg <= 2

    @staticmethod
    def _quadratic_legendre(L_expr, p_vars, xi_vars):
        """
        Analytic Legendre transform for quadratic L: L = 1/2 p^T A p + b^T p + c
        Returns (H_expr, sol_map) and raises ValueError if Hessian singular.
        """
        A = Matrix([[sp.diff(sp.diff(L_expr, p_i), p_j) for p_j in p_vars] for p_i in p_vars])
        grad = Matrix([sp.diff(L_expr, p) for p in p_vars])
        try:
            A_inv = A.inv()
        except Exception:
            raise ValueError("Quadratic analytic path: Hessian A is singular (non-invertible).")
        subs_zero = {p: 0 for p in p_vars}
        b_vec = grad.subs(subs_zero)
        xi_vec = Matrix(xi_vars)
        # Stationarity: xi = grad_p L = A p + b  =>  p = A^{-1} (xi - b)
        p_solution_vec = A_inv * (xi_vec - b_vec)
        sol = {p_vars[i]: sp.simplify(p_solution_vec[i]) for i in range(len(p_vars))}
        H_expr = sum(xi_vars[i] * sol[p_vars[i]] for i in range(len(p_vars))) - sp.simplify(L_expr.subs(sol))
        return sp.simplify(H_expr), sol

    # ----------------------------
    # Numeric Legendre-Fenchel helpers
    # ----------------------------
    @staticmethod
    def _legendre_fenchel_1d_numeric_callable(L_func, p_bounds=(-10.0, 10.0), n_grid=2001, mode="auto",
                                             scipy_multistart=5):
        """
        Return a callable H_numeric(xi) = sup_p (xi*p - L(p)) for 1D L_func(p).
        - L_func: callable p -> L(p)
        - mode: "auto" | "scipy" | "grid"
        """
        pmin, pmax = float(p_bounds[0]), float(p_bounds[1])

        def _compute_by_grid(xi):
            grid = _np.linspace(pmin, pmax, int(n_grid))
            Lvals = _np.array([float(L_func(p)) for p in grid], dtype=float)
            S = xi * grid - Lvals
            idx = int(_np.argmax(S))
            return float(S[idx]), float(grid[idx])

        def _compute_by_scipy(xi):
            if not _HAS_SCIPY:
                return _compute_by_grid(xi)

            def negS(p):
                p0 = float(p[0])
                return -(xi * p0 - float(L_func(p0)))

            best_val = -_math.inf
            best_p = None
            inits = _np.linspace(pmin, pmax, max(3, int(scipy_multistart)))
            for x0 in inits:
                try:
                    res = _optimize.minimize(negS, x0=[float(x0)], bounds=[(pmin, pmax)], method="L-BFGS-B")
                    if res.success:
                        pstar = float(res.x[0])
                        sval = float(xi * pstar - float(L_func(pstar)))
                        if sval > best_val:
                            best_val = sval
                            best_p = pstar
                except Exception:
                    continue
            if best_p is None:
                return _compute_by_grid(xi)
            return best_val, best_p

        compute = _compute_by_scipy if (_HAS_SCIPY and mode != "grid") else _compute_by_grid

        def H_numeric(xi_in):
            xi_arr = _np.atleast_1d(xi_in).astype(float)
            out = _np.empty_like(xi_arr, dtype=float)
            for i, xi in enumerate(xi_arr):
                val, _ = compute(float(xi))
                out[i] = val
            # scalar in -> scalar out (also for 0-d NumPy arrays)
            if _np.ndim(xi_in) == 0:
                return float(out[0])
            return out

        return H_numeric

    @staticmethod
    def _legendre_fenchel_nd_numeric_callable(L_func, dim, p_bounds, n_grid_per_dim=41, mode="auto",
                                              scipy_multistart=10, multistart_restarts=8):
        """
        Return callable H_numeric(xi_vector) approximating sup_p (xi·p - L(p)) for dim>=2.
        - L_func: callable p_vector -> L(p)
        - p_bounds: sequence of per-dimension (pmin, pmax) pairs, e.g. [(-10, 10), (-10, 10)]
        - scipy_multistart: currently unused; the number of random restarts is set by multistart_restarts
        """
        # One (min, max) pair per dimension
        pmin = [float(b[0]) for b in p_bounds]
        pmax = [float(b[1]) for b in p_bounds]

        def compute_by_grid(xi_vec):
            import itertools
            grids = [_np.linspace(pmin[d], pmax[d], int(n_grid_per_dim)) for d in range(dim)]
            best = -_math.inf
            best_p = None
            for pt in itertools.product(*grids):
                pt_arr = _np.array(pt, dtype=float)
                sval = float(_np.dot(xi_vec, pt_arr) - L_func(pt_arr))
                if sval > best:
                    best = sval
                    best_p = pt_arr
            return best, best_p

        def compute_by_scipy(xi_vec):
            if not _HAS_SCIPY:
                return compute_by_grid(xi_vec)

            def negS(p):
                p = _np.asarray(p, dtype=float)
                return - (float(_np.dot(xi_vec, p)) - float(L_func(p)))

            best_val = -_math.inf
            best_p = None
            center = _np.array([(pmin[d] + pmax[d]) / 2.0 for d in range(dim)], dtype=float)
            rng = _np.random.default_rng(123456)
            inits = [center]
            for k in range(multistart_restarts):
                r = rng.random(dim)
                start = _np.array([pmin[d] + r[d] * (pmax[d] - pmin[d]) for d in range(dim)], dtype=float)
                inits.append(start)
            for x0 in inits:
                try:
                    res = _optimize.minimize(negS, x0=x0, bounds=tuple((pmin[d], pmax[d]) for d in range(dim)),
                                             method="L-BFGS-B")
                    if res.success:
                        pstar = _np.asarray(res.x, dtype=float)
                        sval = float(_np.dot(xi_vec, pstar) - L_func(pstar))
                        if sval > best_val:
                            best_val = sval
                            best_p = pstar
                except Exception:
                    continue
            if best_p is None:
                return compute_by_grid(xi_vec)
            return best_val, best_p

        compute = compute_by_scipy if (_HAS_SCIPY and mode != "grid") else compute_by_grid

        def H_numeric(xi_in):
            xi_arr = _np.atleast_2d(xi_in).astype(float)
            if xi_arr.shape[-1] != dim:
                xi_arr = xi_arr.reshape(-1, dim)
            out = _np.empty((xi_arr.shape[0],), dtype=float)
            for i, xivec in enumerate(xi_arr):
                val, _ = compute(xivec)
                out[i] = val
            if out.shape[0] == 1:
                return float(out[0])
            return out

        return H_numeric

    # ----------------------------
    # Main methods
    # ----------------------------
    @staticmethod
    def L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False,
               method="legendre", fenchel_opts=None):
        """
        Convert L(x,u,p) -> H(x,u,xi) with options for generalized Legendre (Fenchel).

        Parameters:
          - method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
          - fenchel_opts: dict with options for numeric fenchel (p_bounds, n_grid or
            n_grid_per_dim, mode, scipy_multistart, multistart_restarts)
          - with "fenchel_numeric", L_expr must depend on p_vars only
        """
        dim = len(coords)
        if dim == 1:
            xi_vars = (sp.Symbol('xi', real=True),)
        elif dim == 2:
            xi_vars = (sp.Symbol('xi', real=True), sp.Symbol('eta', real=True))
        else:
            raise ValueError("Only 1D and 2D dimensions are supported.")

        # Quadratic fast-path (symbolic)
        if method in ("legendre", "fenchel_symbolic") and LagrangianHamiltonianConverter._is_quadratic_in_p(L_expr, p_vars):
            try:
                H_expr, sol = LagrangianHamiltonianConverter._quadratic_legendre(L_expr, p_vars, xi_vars)
                if return_symbol_only:
                    H_expr = H_expr.subs(u, 0)
                return H_expr, xi_vars
            except Exception:
                if not force and method == "legendre":
                    raise

        # CLASSICAL LEGENDRE
        if method == "legendre":
            H_p = None
            try:
                H_p = sp.hessian(L_expr, p_vars)
                det_H = sp.simplify(H_p.det())
            except Exception:
                det_H = None

            if det_H is not None and det_H == 0 and not force:
                raise ValueError("Legendre transform not invertible: Hessian singular. Use force=True or Fenchel method.")
            if det_H is None and not force:
                raise ValueError("Unable to verify Hessian determinant symbolically. Use force=True to attempt solve().")

            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
            sol_list = sp.solve(eqs, p_vars, dict=True)
            if not sol_list:
                if not force:
                    raise ValueError("Unable to solve symbolic Legendre relations. Use force=True or Fenchel fallback.")
            if sol_list:
                sol = sol_list[0]
                if isinstance(sol, tuple) and len(sol) == len(p_vars):
                    sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
                H_expr = sum(xi_vars[i]*sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
                H_expr = sp.simplify(H_expr)
                if return_symbol_only:
                    H_expr = H_expr.subs(u, 0)
                return H_expr, xi_vars
            raise ValueError("Legendre inversion failed even with solve().")

        # FENCHEL: symbolic attempt
        # -----------------------------------------------------
        #  Prevent symbolic Fenchel when L is non-differentiable
        # -----------------------------------------------------
        if method == "fenchel_symbolic":
            if L_expr.has(sp.Abs) or L_expr.has(sp.sign) or any(
                sp.diff(L_expr, p).has(sp.sign, sp.Abs) for p in p_vars
            ):
                raise ValueError(
                    "Symbolic Fenchel not possible for nonsmooth L (Abs, sign). "
                    "Use method='fenchel_numeric' instead."
                )

        if method == "fenchel_symbolic":
            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
            sol_list = sp.solve(eqs, p_vars, dict=True)
            if sol_list:
                candidates = []
                for sol in sol_list:
                    if isinstance(sol, tuple) and len(sol) == len(p_vars):
                        sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
                    S_expr = sum(xi_vars[i] * sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
                    candidates.append(sp.simplify(S_expr))
                H_candidates = sp.simplify(sp.Max(*candidates)) if len(candidates) > 1 else candidates[0]
                if return_symbol_only:
                    H_candidates = H_candidates.subs(u, 0)
                return H_candidates, xi_vars
            raise ValueError("Symbolic Fenchel conjugate not found; use method='fenchel_numeric' for numeric computation.")

        # FENCHEL: numeric path
        if method == "fenchel_numeric":
            if fenchel_opts is None:
                fenchel_opts = {}
            # The conjugate is evaluated pointwise in p only: every other symbol
            # (coordinates, u, parameters) must be given a numeric value first.
            extra_syms = L_expr.free_symbols - set(p_vars)
            if extra_syms:
                raise ValueError(
                    "fenchel_numeric requires L to depend on p_vars only; "
                    "substitute numeric values for {} first.".format(sorted(str(s) for s in extra_syms))
                )
            if dim == 1:
                p_bounds = fenchel_opts.get("p_bounds", (-10.0, 10.0))
                n_grid = int(fenchel_opts.get("n_grid", 2001))
                mode = fenchel_opts.get("mode", "auto")
                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 8))

                # Build numeric L_func (try lambdify)
                try:
                    f_lamb = sp.lambdify((p_vars[0],), L_expr, "numpy")
                    def L_func_scalar(p):
                        return float(f_lamb(p))
                except Exception:
                    try:
                        f_lamb = sp.lambdify(p_vars[0], L_expr, "numpy")
                        def L_func_scalar(p):
                            return float(f_lamb(p))
                    except Exception:
                        def L_func_scalar(p):
                            return float(sp.N(L_expr.subs({p_vars[0]: p})))

                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_1d_numeric_callable(
                    L_func_scalar, p_bounds=p_bounds, n_grid=n_grid, mode=mode,
                    scipy_multistart=scipy_multistart
                )
                H_func = sp.Function("H_numeric")
                H_repr = H_func(xi_vars[0])
                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
                return H_repr, xi_vars, H_numeric

            else:
                # dim == 2; p_bounds holds one (pmin, pmax) pair per dimension
                p_bounds = fenchel_opts.get("p_bounds", [(-10.0, 10.0), (-10.0, 10.0)])
                n_grid_per_dim = int(fenchel_opts.get("n_grid_per_dim", 41))
                mode = fenchel_opts.get("mode", "auto")
                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 20))
                multistart_restarts = int(fenchel_opts.get("multistart_restarts", 8))

                f_lamb = None
                try:
                    f_lamb = sp.lambdify((p_vars[0], p_vars[1]), L_expr, "numpy")
                    def L_func_nd(p):
                        return float(f_lamb(float(p[0]), float(p[1])))
                except Exception:
                    try:
                        f_lamb = sp.lambdify((p_vars,), L_expr, "numpy")
                        def L_func_nd(p):
                            return float(f_lamb(tuple(float(v) for v in p)))
                    except Exception:
                        def L_func_nd(p):
                            subs_map = {p_vars[i]: float(p[i]) for i in range(2)}
                            return float(sp.N(L_expr.subs(subs_map)))

                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_nd_numeric_callable(
                    L_func_nd, dim=2, p_bounds=(p_bounds[0], p_bounds[1]),
                    n_grid_per_dim=n_grid_per_dim, mode=mode,
                    scipy_multistart=scipy_multistart, multistart_restarts=multistart_restarts
                )
                H_func = sp.Function("H_numeric")
                H_repr = H_func(*xi_vars)
                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
                return H_repr, xi_vars, H_numeric

        raise ValueError("Unknown method '{}'. Choose 'legendre', 'fenchel_symbolic' or 'fenchel_numeric'.".format(method))

    @staticmethod
    def H_to_L(H_expr, coords, u, xi_vars, force=False):
        """
        Inverse Legendre (classical). Does not attempt Fenchel inverse.
        """
        dim = len(coords)
        if dim == 1:
            p_vars = (sp.Symbol('p', real=True),)
        elif dim == 2:
            p_vars = (sp.Symbol('p_x', real=True), sp.Symbol('p_y', real=True))
        else:
            raise ValueError("Only 1D and 2D are supported.")

        eqs = [sp.Eq(sp.diff(H_expr, xi_vars[i]), p_vars[i]) for i in range(dim)]
        sol = sp.solve(eqs, xi_vars, dict=True)
        if not sol:
            if not force:
                raise ValueError("Unable to symbolically solve p = ∂H/∂ξ for ξ. Use force=True.")
            sol = sp.solve(eqs, xi_vars)
        if not sol:
            raise ValueError("Inverse Legendre transform failed; cannot find ξ(p).")
        sol = sol[0] if isinstance(sol, list) else sol
        if isinstance(sol, tuple) and len(sol) == len(xi_vars):
            sol = {xi_vars[i]: sol[i] for i in range(len(xi_vars))}
        if not isinstance(sol, dict):
            if isinstance(sol, list) and sol and isinstance(sol[0], dict):
                sol = sol[0]
            else:
                raise ValueError("Unexpected output from solve(); cannot construct ξ(p).")
415                raise ValueError("Unexpected output from solve(); cannot construct ξ(p).")
416        L_expr = sum(sol[xi_vars[i]] * p_vars[i] for i in range(dim)) - H_expr.subs(sol)
417        return sp.simplify(L_expr), p_vars

Bidirectional converter between Lagrangian and Hamiltonian (Legendre transform), with optional Legendre–Fenchel (convex conjugate) support and a numeric fallback for Lagrangians that cannot be handled symbolically.

Main API: L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False, method="legendre", fenchel_opts=None)

  • method: "legendre" (default), "fenchel_symbolic", or "fenchel_numeric"
  • Returns (H_repr, xi_vars, numeric_callable) when method == "fenchel_numeric", and (H_expr, xi_vars) otherwise
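The classical "legendre" path has three steps. It checks that the Hessian ∂²L/∂p² is nonsingular. It then solves ∂L/∂p = ξ for p. Finally it forms H = ξ·p - L.

The sketch below reproduces those steps in plain SymPy for a 1D Lagrangian L = m p²/2 - V(x), so the expected result can be seen directly. The helper name `legendre_1d` is hypothetical and is not part of psipy's API.

```python
import sympy as sp


def legendre_1d(L_expr, p, xi):
    """Hypothetical helper: classical 1D Legendre transform H(xi) = xi*p - L(p)."""
    # Invertibility check: d^2L/dp^2 must not vanish identically
    if sp.simplify(sp.diff(L_expr, p, 2)) == 0:
        raise ValueError("Hessian is singular; Legendre transform not invertible")
    # Solve the momentum relation dL/dp = xi for p (first branch only)
    sols = sp.solve(sp.Eq(sp.diff(L_expr, p), xi), p, dict=True)
    if not sols:
        raise ValueError("could not solve dL/dp = xi for p")
    p_of_xi = sols[0][p]
    # H = xi * p(xi) - L(p(xi))
    return sp.simplify(xi * p_of_xi - L_expr.subs(p, p_of_xi))


x = sp.Symbol("x", real=True)
p = sp.Symbol("p", real=True)
xi = sp.Symbol("xi", real=True)
m = sp.Symbol("m", positive=True)
V = sp.Function("V")(x)

H = legendre_1d(m * p**2 / 2 - V, p, xi)
print(H)  # equivalent to xi**2/(2*m) + V(x)
```

A Lagrangian that is linear in p, such as 3*p - V, fails the Hessian check. This mirrors the "Hessian singular" error that L_to_H raises when force=False.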
@staticmethod
def L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False, method='legendre', fenchel_opts=None):

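Convert L(x, u, p) -> H(x, u, xi), with options for the generalized Legendre (Fenchel) transform.

Parameters:

  • method: "legendre" (default), "fenchel_symbolic", or "fenchel_numeric"
  • fenchel_opts: optional dict, read only by method="fenchel_numeric". Recognized keys (defaults in parentheses): "p_bounds" ((-10.0, 10.0) per axis), "mode" ("auto"), "scipy_multistart" (8 in 1D, 20 in 2D); 1D only: "n_grid" (2001); 2D only: "n_grid_per_dim" (41), "multistart_restarts" (8)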
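For nonsmooth Lagrangians, the "fenchel_numeric" path evaluates the convex conjugate H(ξ) = sup_p [ξ p - L(p)] numerically. The sketch below shows only the brute-force grid idea in NumPy. It does not reproduce psipy's internal `_legendre_fenchel_1d_numeric_callable`, whose "mode" and "scipy_multistart" refinements are not shown in this section. The name `fenchel_conjugate_grid` is hypothetical.

The test Lagrangian L(p) = |p| + p²/2 has the closed-form conjugate H(ξ) = max(|ξ| - 1, 0)²/2, which makes the grid result easy to check.

```python
import numpy as np


def fenchel_conjugate_grid(L_func, xi_values, p_bounds=(-10.0, 10.0), n_grid=2001):
    """Hypothetical helper: H(xi) = max over a p-grid of (xi * p - L(p))."""
    p_grid = np.linspace(p_bounds[0], p_bounds[1], n_grid)
    L_vals = L_func(p_grid)                          # sample L once on the grid
    xi_values = np.atleast_1d(np.asarray(xi_values, dtype=float))
    # Row i holds xi_i * p - L(p) for every grid point p
    objective = np.outer(xi_values, p_grid) - L_vals[None, :]
    return objective.max(axis=1)


def L_nonsmooth(p):
    return np.abs(p) + 0.5 * p**2


xi = np.linspace(-3.0, 3.0, 13)
H_grid = fenchel_conjugate_grid(L_nonsmooth, xi)
H_exact = 0.5 * np.maximum(np.abs(xi) - 1.0, 0.0) ** 2
print(np.max(np.abs(H_grid - H_exact)))  # small, limited by the grid spacing
```

The grid only sees p inside p_bounds. The result is therefore reliable only while the maximizing p stays inside that interval; here the maximizer is p = ±(|ξ| - 1).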
@staticmethod
def H_to_L(H_expr, coords, u, xi_vars, force=False):

Inverse Legendre (classical). Does not attempt Fenchel inverse.
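The inverse direction solves p = ∂H/∂ξ for ξ and forms L = p·ξ - H. H_to_L does the same, keeping the first solution branch returned by sp.solve.

The minimal round trip below uses plain SymPy. It assumes a quadratic Hamiltonian with an arbitrary symbolic potential V(x).

```python
import sympy as sp

x = sp.Symbol("x", real=True)
p = sp.Symbol("p", real=True)
xi = sp.Symbol("xi", real=True)
m = sp.Symbol("m", positive=True)
V = sp.Function("V")(x)

H = xi**2 / (2 * m) + V

# Solve p = dH/dxi for xi (here xi = m*p), keeping the first branch
xi_of_p = sp.solve(sp.Eq(sp.diff(H, xi), p), xi, dict=True)[0][xi]

# Inverse Legendre transform: L = p*xi(p) - H(xi(p))
L = sp.simplify(p * xi_of_p - H.subs(xi, xi_of_p))
print(L)  # equivalent to m*p**2/2 - V(x)
```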

class HamiltonianSymbolicConverter:
class HamiltonianSymbolicConverter:
    """
    Symbolic converter between Hamiltonians and formal PDEs (psiOp).
    """

    @staticmethod
    def decompose_hamiltonian(H_expr, xi_vars):
        """
        Decomposes the Hamiltonian into polynomial (local) and non-polynomial (nonlocal) parts.
        The heuristic treats terms containing sqrt, Abs, or sign as nonlocal.
        """
        xi = xi_vars if isinstance(xi_vars, (tuple, list)) else (xi_vars,)
        poly_terms, nonlocal_terms = 0, 0
        H_expand = sp.expand(H_expr)
        for term in H_expand.as_ordered_terms():
            # sqrt(a) is represented as Pow(a, 1/2) (Pow(a, -1/2) in a denominator),
            # so square roots are detected through the exponent of each power.
            has_sqrt = any(pw.exp in (sp.S.Half, -sp.S.Half) for pw in term.atoms(sp.Pow))
            if has_sqrt or term.has(sp.Abs, sp.sign):
                nonlocal_terms += term
            elif all(term.is_polynomial(xi_i) for xi_i in xi):
                poly_terms += term
            else:
                nonlocal_terms += term
        return sp.simplify(poly_terms), sp.simplify(nonlocal_terms)

    @classmethod
    def hamiltonian_to_symbolic_pde(cls, H_expr, coords, t, u, mode="schrodinger"):
        dim = len(coords)
        if dim == 1:
            xi_vars = (sp.Symbol("xi", real=True),)
        elif dim == 2:
            xi_vars = (sp.Symbol("xi", real=True), sp.Symbol("eta", real=True))
        else:
            raise ValueError("Only 1D and 2D Hamiltonians are supported.")

        H_poly, H_nonlocal = cls.decompose_hamiltonian(H_expr, xi_vars)
        H_total = H_poly + H_nonlocal
        psiOp_H_u = sp.Function("psiOp")(H_total, u)

        if mode == "stationary":
            E = sp.Symbol("E", real=True)
            pde = sp.Eq(psiOp_H_u, E * u)
            formal = "ψOp(H, u) = E u"
        elif mode == "schrodinger":
            pde = sp.Eq(sp.I * sp.Derivative(u, t), psiOp_H_u)
            formal = "i ∂_t u = ψOp(H, u)"
        elif mode == "wave":
            pde = sp.Eq(sp.Derivative(u, (t, 2)), -psiOp_H_u)
            formal = "∂_{tt} u + ψOp(H, u) = 0"
        else:
            raise ValueError("mode must be one of: 'stationary', 'schrodinger', 'wave'.")

        coord_str = ", ".join(str(c) for c in coords)
        xi_str = ", ".join(str(x) for x in xi_vars)
        formal += f"   (H = H({coord_str}; {xi_str}))"

        return {
            "pde": sp.simplify(pde),
            "H_poly": H_poly,
            "H_nonlocal": H_nonlocal,
            "formal_string": formal,
            "mode": mode
        }

Symbolic converter between Hamiltonians and formal PDEs (psiOp).
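The "schrodinger" mode of hamiltonian_to_symbolic_pde produces a formal equation of the form i ∂_t u = ψOp(H, u). The sketch below builds that equation by hand for a 1D Hamiltonian. As in the listing, `psiOp` is an undefined SymPy function that stands for the quantized operator; the equations are formal and are never evaluated.

```python
import sympy as sp

t, x = sp.symbols("t x", real=True)
xi = sp.Symbol("xi", real=True)
u = sp.Function("u")(t, x)

# Symbol of a 1D operator: kinetic term plus a harmonic potential
H = xi**2 + x**2

# Formal pseudo-differential action (undefined function, never evaluated)
psiOp_H_u = sp.Function("psiOp")(H, u)

schrodinger = sp.Eq(sp.I * sp.Derivative(u, t), psiOp_H_u)   # i du/dt = psiOp(H, u)
wave = sp.Eq(sp.Derivative(u, (t, 2)), -psiOp_H_u)             # d2u/dt2 = -psiOp(H, u)
stationary = sp.Eq(psiOp_H_u, sp.Symbol("E", real=True) * u)   # psiOp(H, u) = E u
print(schrodinger)
```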

@staticmethod
def decompose_hamiltonian(H_expr, xi_vars):

Decomposes the Hamiltonian into polynomial (local) and non-polynomial (nonlocal) parts. The heuristic treats terms containing sqrt, Abs, or sign as nonlocal.
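The sketch below reimplements the splitting heuristic in plain SymPy so its behaviour can be checked in isolation. The name `split_local_nonlocal` is hypothetical. SymPy stores sqrt(a) as Pow(a, 1/2), so square roots are detected through the exponent of each power rather than by searching for the `sp.sqrt` function.

```python
import sympy as sp


def split_local_nonlocal(H_expr, xi_vars):
    """Hypothetical re-implementation of the polynomial/nonlocal split."""
    poly_part, nonlocal_part = sp.S.Zero, sp.S.Zero
    for term in sp.expand(H_expr).as_ordered_terms():
        # sqrt(a) is Pow(a, 1/2); 1/sqrt(a) is Pow(a, -1/2)
        has_sqrt = any(pw.exp in (sp.S.Half, -sp.S.Half) for pw in term.atoms(sp.Pow))
        if has_sqrt or term.has(sp.Abs, sp.sign):
            nonlocal_part += term          # nonsmooth or root term: nonlocal
        elif all(term.is_polynomial(v) for v in xi_vars):
            poly_part += term              # polynomial in momenta: local
        else:
            nonlocal_part += term          # any other non-polynomial dependence
    return poly_part, nonlocal_part


x, xi = sp.symbols("x xi", real=True)
H = xi**2 + x**2 + sp.sqrt(xi**2 + 1) + sp.Abs(xi)
H_poly, H_nonlocal = split_local_nonlocal(H, (xi,))
print(H_poly)      # the local part, x**2 + xi**2
print(H_nonlocal)  # the nonlocal part, sqrt(xi**2 + 1) + Abs(xi)
```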

@classmethod
def hamiltonian_to_symbolic_pde(cls, H_expr, coords, t, u, mode='schrodinger'):
class SymbolGeometry:
class SymbolGeometry:
    """
    Analyzes the geometric structure of a symbol H(x, ξ)

    This class computes:
    - Hamiltonian flow (geodesics)
    - Jacobian (focusing)
    - Caustics (singularities)
    - Periodic orbits
    - Semiclassical spectrum
    """

    def __init__(self, symbol: sp.Expr, x_sym: sp.Symbol, xi_sym: sp.Symbol):
        """
        Initialize with a symbolic Hamiltonian

        Parameters
        ----------
        symbol : sympy expression
            The Hamiltonian H(x, ξ)
        x_sym, xi_sym : sympy symbols
            Position and momentum variables
        """
        self.H = symbol
        self.x_sym = x_sym
        self.xi_sym = xi_sym

        # Compute derivatives symbolically (DRY principle)
        self._compute_derivatives()

        # Convert to numerical functions (cached)
        self._lambdify_functions()

    def _compute_derivatives(self):
        """Compute all necessary derivatives (DRY)"""
        dH_x = sp.diff(self.H, self.x_sym)
        self.dH_dx = _sanitize(dH_x)
        dH_xi = sp.diff(self.H, self.xi_sym)
        self.dH_dxi = _sanitize(dH_xi)
        d2H_x2 = sp.diff(self.dH_dx, self.x_sym)
        self.d2H_dx2 = _sanitize(d2H_x2)
        d2H_xi2 = sp.diff(self.dH_dxi, self.xi_sym)
        self.d2H_dxi2 = _sanitize(d2H_xi2)
        d2H_xxi = sp.diff(self.dH_dx, self.xi_sym)
        self.d2H_dxdxi = _sanitize(d2H_xxi)

    def _lambdify_functions(self):
        """Convert symbolic expressions to numerical functions (DRY)"""
        vars_tuple = (self.x_sym, self.xi_sym)

        self.f_H = sp.lambdify(vars_tuple, self.H, 'numpy')
        self.f_dH_dx = sp.lambdify(vars_tuple, self.dH_dx, 'numpy')
        self.f_dH_dxi = sp.lambdify(vars_tuple, self.dH_dxi, 'numpy')
        self.f_d2H_dx2 = sp.lambdify(vars_tuple, self.d2H_dx2, 'numpy')
        self.f_d2H_dxi2 = sp.lambdify(vars_tuple, self.d2H_dxi2, 'numpy')
        self.f_d2H_dxdxi = sp.lambdify(vars_tuple, self.d2H_dxdxi, 'numpy')

    def compute_geodesic(self, x0: float, xi0: float, t_max: float,
                        n_points: int = 500) -> Geodesic:
        """
        Compute geodesic with Jacobian (for caustics detection)

        Solves the augmented system:
        dx/dt = ∂H/∂ξ
        dξ/dt = -∂H/∂x
        dJ/dt = ∂²H/∂x∂ξ J + ∂²H/∂ξ² K  (variational equation)
        dK/dt = -∂²H/∂x² J - ∂²H/∂x∂ξ K

        With J(0) = 0 and K(0) = 1, J(t) = ∂x(t)/∂ξ0, so caustics are the
        points t > 0 where J vanishes.

        Parameters
        ----------
        x0, xi0 : float
            Initial conditions
        t_max : float
            Final time
        n_points : int
            Number of points

        Returns
        -------
        Geodesic
            Complete geodesic information
        """
        def system(t, z):
            x, xi, J, K = z
            try:
                # Hamilton equations
                dx = float(self.f_dH_dxi(x, xi))
                dxi = float(-self.f_dH_dx(x, xi))

                # Variational equations: linearized flow applied to (J, K)
                d2H_dxi2 = float(self.f_d2H_dxi2(x, xi))
                d2H_dxdxi = float(self.f_d2H_dxdxi(x, xi))
                d2H_dx2 = float(self.f_d2H_dx2(x, xi))

                dJ = d2H_dxdxi * J + d2H_dxi2 * K
                dK = -d2H_dx2 * J - d2H_dxdxi * K

                return [dx, dxi, dJ, dK]
            except Exception:
                return [0, 0, 0, 0]

        # Initial conditions: J(0)=0, K(0)=1 (perturbation of the initial momentum)
        z0 = [x0, xi0, 0.0, 1.0]

        sol = solve_ivp(
            system, [0, t_max], z0,
            t_eval=np.linspace(0, t_max, n_points),
            method='DOP853',
            rtol=1e-10, atol=1e-12
        )

        # Compute energy along trajectory
        H_traj = np.array([self.f_H(sol.y[0][i], sol.y[1][i])
                          for i in range(len(sol.t))])

        return Geodesic(
            t=sol.t,
            x=sol.y[0],
            xi=sol.y[1],
            H=H_traj,
            J=sol.y[2],
            K=sol.y[3]
        )

    def find_periodic_orbits(self, energy: float,
                            x_range: Tuple[float, float],
                            xi_range: Tuple[float, float],
                            n_attempts: int = 50,
                            tol_period: float = 1e-3) -> List[PeriodicOrbit]:
        """
        Find periodic orbits at fixed energy

        Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits

        Parameters
        ----------
        energy : float
            Target energy level
        x_range, xi_range : tuple
            Search domain
        n_attempts : int
            Search density: int(sqrt(n_attempts)) x-samples, each with
            5 starting guesses for ξ
        tol_period : float
            Tolerance for periodicity detection

        Returns
        -------
        list of PeriodicOrbit
            Found periodic orbits
        """
        orbits = []
        x_samples = np.linspace(x_range[0], x_range[1], int(np.sqrt(n_attempts)))

        for x0_test in x_samples:
            # Solve H(x0, ξ0) = E for ξ0
            def energy_eq(xi0):
                try:
                    return self.f_H(x0_test, xi0) - energy
                except Exception:
                    return 1e10

            xi_guesses = np.linspace(xi_range[0], xi_range[1], 5)

            for xi_guess in xi_guesses:
                try:
                    result = fsolve(energy_eq, xi_guess, full_output=True)

                    if result[2] != 1:  # Check convergence
                        continue

                    xi0 = result[0][0]

                    # Verify we're on energy surface
                    if abs(self.f_H(x0_test, xi0) - energy) > 1e-6:
                        continue

                    # Integrate over a fixed window; periods longer than T_max are not detected
                    T_max = 20
                    geo = self.compute_geodesic(x0_test, xi0, T_max, 2000)

                    # Find returns to initial point
                    distances = np.sqrt((geo.x - x0_test)**2 + (geo.xi - xi0)**2)

                    # Find local minima (except t=0)
                    minima_idx = []
                    for i in range(10, len(distances)-10):
                        if (distances[i] < distances[i-1] and
                            distances[i] < distances[i+1] and
                            distances[i] < tol_period):
                            minima_idx.append(i)

                    if minima_idx:
                        idx_period = minima_idx[0]
                        period = geo.t[idx_period]

                        if period > 0.1 and distances[idx_period] < tol_period:
                            # Compute action S = ∮ ξ dx
                            x_cycle = geo.x[:idx_period+1]
                            xi_cycle = geo.xi[:idx_period+1]
                            t_cycle = geo.t[:idx_period+1]

                            dx_dt = np.gradient(x_cycle, t_cycle)
                            # np.trapz was renamed np.trapezoid in NumPy 2.0
                            trapezoid = getattr(np, "trapezoid", None) or np.trapz
                            action = trapezoid(xi_cycle * dx_dt, t_cycle)

                            # Compute stability (Lyapunov exponent)
                            stability = self._compute_stability(x0_test, xi0, period)

                            orbits.append(PeriodicOrbit(
                                x0=x0_test,
                                xi0=xi0,
                                period=period,
                                action=action,
                                energy=energy,
                                stability=stability,
                                x_cycle=x_cycle,
                                xi_cycle=xi_cycle,
                                t_cycle=t_cycle
                            ))

                except Exception:
                    continue

        # Remove duplicates
        return self._remove_duplicate_orbits(orbits)

    def _compute_stability(self, x0: float, xi0: float, T: float) -> float:
        """Estimate the Lyapunov exponent (orbit stability) from one perturbation over time T"""
        def linearized_system(t, z):
            x, xi, dx, dxi = z
            try:
                vx = float(self.f_dH_dxi(x, xi))
                vxi = float(-self.f_dH_dx(x, xi))

                # Linearized Hamiltonian flow: d/dt (dx, dxi) = A (dx, dxi),
                # A = [[H_xξ, H_ξξ], [-H_xx, -H_xξ]]
                H_xxi = float(self.f_d2H_dxdxi(x, xi))
                H_xixi = float(self.f_d2H_dxi2(x, xi))
                H_xx = float(self.f_d2H_dx2(x, xi))

                ddx = H_xxi * dx + H_xixi * dxi
                ddxi = -H_xx * dx - H_xxi * dxi

                return [vx, vxi, ddx, ddxi]
            except Exception:
                return [0, 0, 0, 0]

        epsilon = 1e-6
        z0 = [x0, xi0, epsilon, 0]

        sol = solve_ivp(linearized_system, [0, T], z0, method='DOP853', rtol=1e-10)

        if sol.success and len(sol.y[2]) > 0:
            perturbation_final = np.sqrt(sol.y[2][-1]**2 + sol.y[3][-1]**2)
            return np.log(perturbation_final / epsilon) / T
        else:
            return np.nan

    def _remove_duplicate_orbits(self, orbits: List[PeriodicOrbit]) -> List[PeriodicOrbit]:
        """Remove duplicate periodic orbits"""
        unique = []
        for orb in orbits:
            is_duplicate = False
            for orb_unique in unique:
                if (abs(orb.period - orb_unique.period) < 0.1 and
                    abs(orb.action - orb_unique.action) < 0.1):
                    is_duplicate = True
                    break
            if not is_duplicate:
                unique.append(orb)
        return unique

    def gutzwiller_trace_formula(self, periodic_orbits: List[PeriodicOrbit],
                                 t_values: np.ndarray, hbar: float = 1.0) -> np.ndarray:
        """
        Gutzwiller trace formula (semiclassical)

        Tr[exp(-iHt/ℏ)] ≈ Σ_γ A_γ exp(iS_γ/ℏ - iπμ_γ/2)

        Simplified form: Maslov indices are set to 0, repetitions are
        truncated at k = 4, and each contribution is spread in time by a
        sinc window centred at t = kT.

        Parameters
        ----------
        periodic_orbits : list
            List of periodic orbits
        t_values : array
            Time values
        hbar : float
            Reduced Planck constant

        Returns
        -------
        array
            Trace as function of time
        """
        trace = np.zeros(len(t_values), dtype=complex)

        for orb in periodic_orbits:
            T = orb.period
            S = orb.action
            lambda_stab = orb.stability

            # Include repetitions of the orbit
            for k in range(1, 5):
                T_k = k * T
                S_k = k * S

                # Stability factor
                if not np.isnan(lambda_stab):
                    det_factor = abs(2 * np.sinh(k * lambda_stab * T))
                else:
                    det_factor = 1

                if det_factor > 1e-10:
                    amplitude = T_k / np.sqrt(det_factor)

                    # Maslov index (simplified: 0)
                    mu = 0

                    phase = S_k / hbar - np.pi * mu / 2
                    contribution = amplitude * np.exp(1j * phase) * np.sinc((t_values - T_k) / T_k)
                    trace += contribution

        return trace

    def semiclassical_spectrum(self, periodic_orbits: List[PeriodicOrbit],
                              hbar: float = 1.0, resolution: int = 1000) -> Spectrum:
        """
        Extract semiclassical spectrum via Fourier transform of trace

        Parameters
        ----------
        periodic_orbits : list
            Periodic orbits
        hbar : float
            Reduced Planck constant
        resolution : int
            Number of points

        Returns
        -------
        Spectrum
            Spectral information
        """
        t_max = 50 / hbar
        t_values = np.linspace(0, t_max, resolution)

        trace = self.gutzwiller_trace_formula(periodic_orbits, t_values, hbar)

        # Fourier transform: t → E
        energies_fft = fftfreq(len(t_values), d=t_values[1]-t_values[0]) * 2 * np.pi * hbar
        spectrum_fft = fft(trace)

        return Spectrum(
            energies=energies_fft,
            intensity=np.abs(spectrum_fft),
            trace_t=t_values,
            trace=trace
        )

Analyzes the geometric structure of a symbol H(x, ξ)

This class computes:

  • Hamiltonian flow (geodesics)
  • Jacobian (focusing)
  • Caustics (singularities)
  • Periodic orbits
  • Semiclassical spectrum
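Caustics are the points of a ray family where the Jacobian J = ∂x/∂ξ₀ vanishes, i.e. where neighbouring rays cross. How Geodesic.caustics is computed is not shown in this section. The sketch below is one simple way to locate caustics from a sampled J(t): find sign changes, then interpolate linearly between samples. The helper name find_caustic_times is hypothetical and not part of psipy.

```python
import numpy as np

def find_caustic_times(t, J):
    """Times where the sampled Jacobian J(t) changes sign (hypothetical helper)."""
    t = np.asarray(t, dtype=float)
    J = np.asarray(J, dtype=float)
    s = np.sign(J)
    # i such that J[i] and J[i+1] have strictly opposite signs;
    # samples that are exactly zero (e.g. J(0) = 0) are not counted
    idx = np.where(s[:-1] * s[1:] < 0)[0]
    t0, t1 = t[idx], t[idx + 1]
    j0, j1 = J[idx], J[idx + 1]
    # Linear interpolation of the zero inside [t0, t1]
    return t0 - j0 * (t1 - t0) / (j1 - j0)

# Harmonic oscillator: J(t) = sin t, so rays refocus at t = π, 2π, 3π
t = np.linspace(0.0, 10.0, 2001)
print(find_caustic_times(t, np.sin(t)))  # ≈ [3.1416, 6.2832, 9.4248]
```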
SymbolGeometry( symbol: sympy.core.expr.Expr, x_sym: sympy.core.symbol.Symbol, xi_sym: sympy.core.symbol.Symbol)
def __init__(self, symbol: sp.Expr, x_sym: sp.Symbol, xi_sym: sp.Symbol):
    """
    Initialize with a symbolic Hamiltonian
    
    Parameters
    ----------
    symbol : sympy expression
        The Hamiltonian H(x, ξ)
    x_sym, xi_sym : sympy symbols
        Position and momentum variables
    """
    self.H = symbol
    self.x_sym = x_sym
    self.xi_sym = xi_sym
    
    # Compute derivatives symbolically (DRY principle)
    self._compute_derivatives()
    
    # Convert to numerical functions (cached)
    self._lambdify_functions()

Initialize with a symbolic Hamiltonian

Parameters

symbol : sympy expression
    The Hamiltonian H(x, ξ)
x_sym, xi_sym : sympy symbols
    Position and momentum variables

H
x_sym
xi_sym
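The constructor delegates to two private helpers, _compute_derivatives and _lambdify_functions, whose bodies are not shown here. The sketch below illustrates what such a step typically involves: differentiate once with SymPy, then lambdify each expression for NumPy. It assumes the attribute names f_H, f_dH_dx, f_dH_dxi, f_d2H_dx2, f_d2H_dxi2 and f_d2H_dxdxi that are used elsewhere in this module. The function make_numeric_symbol and its dictionary layout are hypothetical.

```python
import numpy as np
import sympy as sp

def make_numeric_symbol(H: sp.Expr, x: sp.Symbol, xi: sp.Symbol) -> dict:
    """Differentiate H symbolically once, then lambdify each derivative (hypothetical helper)."""
    exprs = {
        "f_H": H,
        "f_dH_dx": sp.diff(H, x),
        "f_dH_dxi": sp.diff(H, xi),
        "f_d2H_dx2": sp.diff(H, x, 2),
        "f_d2H_dxi2": sp.diff(H, xi, 2),
        "f_d2H_dxdxi": sp.diff(H, x, xi),
    }
    # modules="numpy" lets the generated functions accept scalars or NumPy arrays
    return {name: sp.lambdify((x, xi), e, modules="numpy") for name, e in exprs.items()}

x, xi = sp.symbols("x xi", real=True)
funcs = make_numeric_symbol(xi**2 / 2 + x**4 / 4, x, xi)
print(funcs["f_dH_dx"](2.0, 1.0))     # ∂H/∂x = x³ → 8.0
print(funcs["f_d2H_dxi2"](2.0, 1.0))  # ∂²H/∂ξ² = 1 → 1
```

A derivative that simplifies to a constant, such as ∂²H/∂ξ² above, lambdifies to a function that returns a scalar even when given arrays. To evaluate it on a meshgrid, you therefore either need an explicit np.broadcast_to(f(X, Xi), X.shape) or a point-by-point evaluation.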
def compute_geodesic( self, x0: float, xi0: float, t_max: float, n_points: int = 500) -> src.geometry_1d.Geodesic:
def compute_geodesic(self, x0: float, xi0: float, t_max: float,
                     n_points: int = 500) -> Geodesic:
    """
    Compute geodesic with Jacobian (for caustics detection)
    
    Solves the augmented system:
    dx/dt = ∂H/∂ξ
    dξ/dt = -∂H/∂x
    dJ/dt = ∂²H/∂x∂ξ J + ∂²H/∂ξ² K  (variational equation)
    dK/dt = -∂²H/∂x² J - ∂²H/∂x∂ξ K
    
    where J = ∂x/∂ξ₀ and K = ∂ξ/∂ξ₀.
    
    Parameters
    ----------
    x0, xi0 : float
        Initial conditions
    t_max : float
        Final time
    n_points : int
        Number of points
        
    Returns
    -------
    Geodesic
        Complete geodesic information
    """
    def system(t, z):
        x, xi, J, K = z
        try:
            # Hamilton equations
            dx = float(self.f_dH_dxi(x, xi))
            dxi = float(-self.f_dH_dx(x, xi))
            
            # Second derivatives of H at the current point
            d2H_dxi2 = float(self.f_d2H_dxi2(x, xi))
            d2H_dxdxi = float(self.f_d2H_dxdxi(x, xi))
            d2H_dx2 = float(self.f_d2H_dx2(x, xi))
            
            # Variational equations: (J, K) = (δx, δξ) evolves under the
            # linearized flow d/dt [J, K] = [[H_xξ, H_ξξ], [-H_xx, -H_xξ]] [J, K]
            dJ = d2H_dxdxi * J + d2H_dxi2 * K
            dK = -d2H_dx2 * J - d2H_dxdxi * K
            
            return [dx, dxi, dJ, dK]
        except Exception:
            # H or its derivatives could not be evaluated here: freeze the state
            return [0.0, 0.0, 0.0, 0.0]
    
    # Initial conditions: J(0)=0, K(0)=1, i.e. perturb only the initial
    # momentum, so that J(t) = ∂x/∂ξ₀
    z0 = [x0, xi0, 0.0, 1.0]
    
    sol = solve_ivp(
        system, [0, t_max], z0,
        t_eval=np.linspace(0, t_max, n_points),
        method='DOP853',
        rtol=1e-10, atol=1e-12
    )
    
    # Compute energy along trajectory
    H_traj = np.array([self.f_H(sol.y[0][i], sol.y[1][i])
                       for i in range(len(sol.t))])
    
    return Geodesic(
        t=sol.t,
        x=sol.y[0],
        xi=sol.y[1],
        H=H_traj,
        J=sol.y[2],
        K=sol.y[3]
    )

Compute geodesic with Jacobian (for caustics detection)

Solves the augmented system:

dx/dt = ∂H/∂ξ
dξ/dt = -∂H/∂x
dJ/dt = ∂²H/∂x∂ξ J + ∂²H/∂ξ² K (variational equation)
dK/dt = -∂²H/∂x² J - ∂²H/∂x∂ξ K

where J = ∂x/∂ξ₀ and K = ∂ξ/∂ξ₀.

Parameters

x0, xi0 : float
    Initial conditions
t_max : float
    Final time
n_points : int
    Number of points

Returns

Geodesic
    Complete geodesic information
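The sketch below integrates the augmented flow for the harmonic oscillator H = (ξ² + x²)/2. It calls SciPy's solve_ivp directly rather than going through SymbolGeometry, and uses the same integrator settings. For this Hamiltonian ∂²H/∂ξ² = ∂²H/∂x² = 1 and ∂²H/∂x∂ξ = 0. With J(0) = 0 and K(0) = 1, the exact solution is therefore J(t) = sin t and K(t) = cos t, so the first caustic (J = 0) occurs at t = π.

```python
import numpy as np
from scipy.integrate import solve_ivp

def augmented_rhs(t, z):
    """Hamilton's equations plus the linearized flow for H = (ξ² + x²)/2."""
    x, xi, J, K = z
    H_xx, H_xixi, H_xxi = 1.0, 1.0, 0.0  # second derivatives (constant here)
    return [
        xi,                      # dx/dt = ∂H/∂ξ
        -x,                      # dξ/dt = -∂H/∂x
        H_xxi * J + H_xixi * K,  # dJ/dt = ∂²H/∂x∂ξ J + ∂²H/∂ξ² K
        -H_xx * J - H_xxi * K,   # dK/dt = -∂²H/∂x² J - ∂²H/∂x∂ξ K
    ]

t_eval = np.linspace(0.0, 2 * np.pi, 401)
sol = solve_ivp(augmented_rhs, [0.0, 2 * np.pi], [1.0, 0.0, 0.0, 1.0],
                t_eval=t_eval, method="DOP853", rtol=1e-10, atol=1e-12)
J = sol.y[2]
print(np.max(np.abs(J - np.sin(sol.t))))  # tiny: J(t) matches sin t
```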

def find_periodic_orbits( self, energy: float, x_range: Tuple[float, float], xi_range: Tuple[float, float], n_attempts: int = 50, tol_period: float = 0.001) -> List[src.geometry_1d.PeriodicOrbit]:
def find_periodic_orbits(self, energy: float,
                         x_range: Tuple[float, float],
                         xi_range: Tuple[float, float],
                         n_attempts: int = 50,
                         tol_period: float = 1e-3) -> List[PeriodicOrbit]:
    """
    Find periodic orbits at fixed energy
    
    Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits
    
    Parameters
    ----------
    energy : float
        Target energy level
    x_range, xi_range : tuple
        Search domain
    n_attempts : int
        Number of initial conditions to try
    tol_period : float
        Tolerance for periodicity detection
        
    Returns
    -------
    list of PeriodicOrbit
        Found periodic orbits
    """
    orbits = []
    x_samples = np.linspace(x_range[0], x_range[1], int(np.sqrt(n_attempts)))
    
    for x0_test in x_samples:
        # Solve H(x0, ξ0) = E for ξ0
        def energy_eq(xi0):
            try:
                return self.f_H(x0_test, xi0) - energy
            except Exception:
                return 1e10
        
        xi_guesses = np.linspace(xi_range[0], xi_range[1], 5)
        
        for xi_guess in xi_guesses:
            try:
                result = fsolve(energy_eq, xi_guess, full_output=True)
                
                if result[2] != 1:  # Check convergence (fsolve returns ier == 1 on success)
                    continue
                
                xi0 = result[0][0]
                
                # Verify we're on energy surface
                if abs(self.f_H(x0_test, xi0) - energy) > 1e-6:
                    continue
                
                # Integrate to detect periodicity
                T_max = 20
                geo = self.compute_geodesic(x0_test, xi0, T_max, 2000)
                
                # Find returns to initial point
                distances = np.sqrt((geo.x - x0_test)**2 + (geo.xi - xi0)**2)
                
                # Find local minima (except t=0)
                minima_idx = []
                for i in range(10, len(distances)-10):
                    if (distances[i] < distances[i-1] and
                        distances[i] < distances[i+1] and
                        distances[i] < tol_period):
                        minima_idx.append(i)
                
                if minima_idx:
                    idx_period = minima_idx[0]
                    period = geo.t[idx_period]
                    
                    if period > 0.1 and distances[idx_period] < tol_period:
                        # Compute action S = ∮ ξ dx
                        x_cycle = geo.x[:idx_period+1]
                        xi_cycle = geo.xi[:idx_period+1]
                        t_cycle = geo.t[:idx_period+1]
                        
                        dx_dt = np.gradient(x_cycle, t_cycle)
                        action = np.trapz(xi_cycle * dx_dt, t_cycle)
                        
                        # Compute stability (Lyapunov exponent)
                        stability = self._compute_stability(x0_test, xi0, period)
                        
                        orbits.append(PeriodicOrbit(
                            x0=x0_test,
                            xi0=xi0,
                            period=period,
                            action=action,
                            energy=energy,
                            stability=stability,
                            x_cycle=x_cycle,
                            xi_cycle=xi_cycle,
                            t_cycle=t_cycle
                        ))
            
            except Exception:
                # Skip initial conditions where root finding or integration fails
                continue
    
    # Remove duplicates
    return self._remove_duplicate_orbits(orbits)

Find periodic orbits at fixed energy

Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits

Parameters

energy : float
    Target energy level
x_range, xi_range : tuple
    Search domain
n_attempts : int
    Number of initial conditions to try
tol_period : float
    Tolerance for periodicity detection

Returns

list of PeriodicOrbit
    Found periodic orbits
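The periodicity test above comes down to four steps. Integrate from a point on H = E and track the phase-space distance to the starting point. Take the first local minimum below a tolerance as the period. Then integrate ξ dx/dt over that cycle to obtain the action. The standalone sketch below applies this to the harmonic oscillator, where T = 2π and S = ∮ ξ dx = 2πE exactly. The helper first_return is hypothetical. It mirrors the method's sampling (2000 points up to t = 20, skipping the first 10 samples) but uses a looser default tolerance.

```python
import numpy as np
from scipy.integrate import solve_ivp

def oscillator_rhs(t, z):
    """Hamilton's equations for H = (ξ² + x²)/2."""
    x, xi = z
    return [xi, -x]

def first_return(x0, xi0, t_max=20.0, n=2000, tol=1e-2, skip=10):
    """Period and action of the first close return to (x0, ξ0), or None (hypothetical helper)."""
    t = np.linspace(0.0, t_max, n)
    sol = solve_ivp(oscillator_rhs, [0.0, t_max], [x0, xi0],
                    t_eval=t, method="DOP853", rtol=1e-10, atol=1e-12)
    x, xi = sol.y
    d = np.hypot(x - x0, xi - xi0)  # phase-space distance to the start
    for i in range(skip, n - 1):
        # First local minimum of the distance that also lies below tol
        if d[i] < d[i - 1] and d[i] < d[i + 1] and d[i] < tol:
            dx_dt = np.gradient(x[: i + 1], t[: i + 1])
            f = xi[: i + 1] * dx_dt
            # Trapezoidal rule for S = ∮ ξ dx = ∫ ξ (dx/dt) dt
            action = np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(t[: i + 1]))
            return t[i], action
    return None

period, action = first_return(1.0, 0.0)  # E = 1/2
print(period, action)                    # ≈ 2π and ≈ π
```

The tolerance has to exceed the phase-space distance travelled between samples, roughly |v|·Δt/2. With 2000 samples over t = 20 (Δt ≈ 0.01) and tol_period = 1e-3, a return is caught only if some sample happens to land close enough to the true period.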

def gutzwiller_trace_formula( self, periodic_orbits: List[src.geometry_1d.PeriodicOrbit], t_values: numpy.ndarray, hbar: float = 1.0) -> numpy.ndarray:

Gutzwiller trace formula (semiclassical)

Tr[exp(-iHt/ℏ)] ≈ Σ_γ A_γ exp(iS_γ/ℏ - iπμ_γ/2)

Parameters

periodic_orbits : list
    List of periodic orbits
t_values : array
    Time values
hbar : float
    Reduced Planck constant

Returns

array
    Trace as function of time
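In the usual statement of Gutzwiller's formula, the amplitude of the k-th repetition of a primitive orbit γ involves |det(M_γ^k − I)|^{1/2}, where M_γ is the 2×2 monodromy matrix. For a hyperbolic orbit with eigenvalues e^{±λT}, this determinant equals −4 sinh²(kλT/2). The short numerical check below uses arbitrary test values for λ and T. A diagonal M is chosen only for convenience: the determinant depends only on the eigenvalues.

```python
import numpy as np

lam, T = 0.7, 1.3  # arbitrary test values for λ and the period T
M = np.diag([np.exp(lam * T), np.exp(-lam * T)])  # hyperbolic monodromy matrix, det M = 1

for k in range(1, 5):  # repetitions k = 1..4
    det_k = np.linalg.det(np.linalg.matrix_power(M, k) - np.eye(2))
    closed_form = -4.0 * np.sinh(k * lam * T / 2) ** 2
    print(k, det_k, closed_form)  # the two columns agree
```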

def semiclassical_spectrum( self, periodic_orbits: List[src.geometry_1d.PeriodicOrbit], hbar: float = 1.0, resolution: int = 1000) -> src.geometry_1d.Spectrum:

Extract semiclassical spectrum via Fourier transform of trace

Parameters

periodic_orbits : list
    Periodic orbits
hbar : float
    Reduced Planck constant
resolution : int
    Number of points

Returns

Spectrum
    Spectral information
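The energy axis above is fftfreq(N, d=Δt)·2πℏ, that is, E = ℏω. The sign convention deserves a check. NumPy's forward FFT uses the kernel e^{−2πi f t}, so a component exp(−iE₀t/ℏ) shows up at −E₀ on this axis, whereas the inverse FFT places it at +E₀. The sketch below builds a single-level trace and locates its peak both ways. E₀ and the time grid are arbitrary test values.

```python
import numpy as np
from numpy.fft import fft, ifft, fftfreq

hbar, E0 = 1.0, 3.0
t = np.linspace(0.0, 100.0, 1000, endpoint=False)  # uniform grid, Δt = 0.1
trace = np.exp(-1j * E0 * t / hbar)  # one level: Tr[exp(-iHt/ℏ)] = exp(-iE₀t/ℏ)

energies = fftfreq(len(t), d=t[1] - t[0]) * 2 * np.pi * hbar
E_fft = energies[np.argmax(np.abs(fft(trace)))]
E_ifft = energies[np.argmax(np.abs(ifft(trace)))]
print(E_fft, E_ifft)  # ≈ -3 and ≈ +3, up to the bin width 2πℏ/(N Δt)
```

Which transform recovers positive levels therefore depends on the time-sign convention carried by the trace. semiclassical_spectrum uses the forward transform, and the spectrum panel keeps only energies > 0.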

class SymbolVisualizer:
461class SymbolVisualizer:
462    """
463    Comprehensive visualization of symbol geometry
464    
465    Produces 15 panels showing:
466    1. Hamiltonian surface (3D)
467    2. Energy level sets (phase space foliation)
468    3. Hamiltonian vector field
469    4. Group velocity ∂H/∂ξ
470    5. Spatial projection (caustics)
471    6. Jacobian (focusing measure)
472    7. Curvature (focusing tendency)
473    8. Energy conservation
474    9. Periodic orbits (phase space)
475    10. Period-energy diagram
476    11. EBK quantization
477    12. Trace formula
478    13. Semiclassical spectrum
479    14. Orbit stability
480    15. Level spacing distribution
481    """
482    
483    def __init__(self, geometry: SymbolGeometry):
484        """
485        Parameters
486        ----------
487        geometry : SymbolGeometry
488            Initialized geometry engine
489        """
490        self.geo = geometry
491    
492    def visualize_complete(self, 
493                          x_range: Tuple[float, float],
494                          xi_range: Tuple[float, float],
495                          geodesics_params: List[Tuple],
496                          E_range: Optional[Tuple[float, float]] = None,
497                          hbar: float = 1.0,
498                          resolution: int = 100) -> Tuple:
499        """
500        Create complete geometric atlas
501        
502        Parameters
503        ----------
504        x_range, xi_range : tuple
505            Domain limits
506        geodesics_params : list of tuples
507            Each tuple: (x0, xi0, t_max, color)
508        E_range : tuple, optional
509            Energy range for spectral analysis
510        hbar : float
511            Reduced Planck constant
512        resolution : int
513            Grid resolution
514            
515        Returns
516        -------
517        fig, geodesics, periodic_orbits, spectrum
518        """
519        # Compute grid
520        x_grid = np.linspace(x_range[0], x_range[1], resolution)
521        xi_grid = np.linspace(xi_range[0], xi_range[1], resolution)
522        X, Xi = np.meshgrid(x_grid, xi_grid)
523        
524        # Evaluate Hamiltonian and derivatives on grid
525        grids = self._evaluate_grids(X, Xi)
526        
527        # Compute geodesics
528        geodesics = self._compute_geodesics(geodesics_params)
529        
530        # Find periodic orbits (if E_range specified)
531        periodic_orbits = []
532        spectrum = None
533        if E_range:
534            energies = np.linspace(E_range[0], E_range[1], 8)
535            for E in energies:
536                orbits = self.geo.find_periodic_orbits(E, x_range, xi_range)
537                periodic_orbits.extend(orbits)
538            
539            if periodic_orbits:
540                spectrum = self.geo.semiclassical_spectrum(periodic_orbits, hbar)
541        
542        # Create figure
543        fig = self._create_figure(X, Xi, grids, geodesics, periodic_orbits, spectrum, hbar)
544        
545        return fig, geodesics, periodic_orbits, spectrum
546    
547    def _evaluate_grids(self, X: np.ndarray, Xi: np.ndarray) -> Dict:
548        """Evaluate all necessary fields on grid (DRY)"""
549        grids = {}
550        
551        for name, func in [
552            ('H', self.geo.f_H),
553            ('dH_dxi', self.geo.f_dH_dxi),
554            ('dH_dx', self.geo.f_dH_dx),
555            ('d2H_dxdxi', self.geo.f_d2H_dxdxi)
556        ]:
557            grid = np.zeros_like(X)
558            for i in range(X.shape[0]):
559                for j in range(X.shape[1]):
560                    try:
561                        grid[i, j] = func(X[i, j], Xi[i, j])
562                    except:
563                        grid[i, j] = np.nan
564            grids[name] = grid
565        
566        return grids
567    
568    def _compute_geodesics(self, params: List[Tuple]) -> List[Geodesic]:
569        """Compute all geodesics"""
570        geodesics = []
571        for p in params:
572            x0, xi0, t_max = p[:3]
573            geo = self.geo.compute_geodesic(x0, xi0, t_max)
574            geo.color = p[3] if len(p) > 3 else 'blue'
575            geodesics.append(geo)
576        return geodesics
577    
578    def _create_figure(self, X, Xi, grids, geodesics, periodic_orbits, spectrum, hbar):
579        """Create the complete visualization figure"""
580        fig = plt.figure(figsize=(24, 18))
581        
582        # Panels 1-8: Geometry
583        self._plot_hamiltonian_surface(fig, X, Xi, grids['H'], geodesics, 1)
584        self._plot_level_sets(fig, X, Xi, grids['H'], geodesics, 2)
585        self._plot_vector_field(fig, X, Xi, grids, geodesics, 3)
586        self._plot_group_velocity(fig, X, Xi, grids['dH_dxi'], geodesics, 4)
587        self._plot_spatial_projection(fig, geodesics, 5)
588        self._plot_jacobian(fig, geodesics, 6)
589        self._plot_curvature(fig, X, Xi, grids['d2H_dxdxi'], geodesics, 7)
590        self._plot_energy_conservation(fig, geodesics, 8)
591        
592        # Panels 9-15: Spectral analysis
593        if periodic_orbits:
594            self._plot_periodic_orbits(fig, X, Xi, grids['H'], periodic_orbits, 9)
595            self._plot_period_energy(fig, periodic_orbits, 10)
596            self._plot_ebk_quantization(fig, periodic_orbits, hbar, 11)
597            
598            if spectrum is not None:
599                self._plot_trace_formula(fig, spectrum, 12)
600                self._plot_spectrum(fig, spectrum, 13)
601                self._plot_stability(fig, periodic_orbits, 14)
602                self._plot_level_spacing(fig, spectrum, 15)
603        
604        plt.suptitle(f'Geometric and Semiclassical Atlas: H = {self.geo.H}',
605                     fontsize=18, fontweight='bold', y=0.995)
606        plt.tight_layout(rect=[0, 0, 1, 0.98])
607
608        
609        return fig
610    
611    # Individual plotting methods (KISS principle: each does one thing)
612    
613    def _plot_hamiltonian_surface(self, fig, X, Xi, H_grid, geodesics, panel):
614        """Panel 1: Hamiltonian surface in 3D"""
615        ax = fig.add_subplot(3, 5, panel, projection='3d')
616        ax.plot_surface(X, Xi, H_grid, cmap='viridis', alpha=0.8, 
617                        linewidth=0, antialiased=True)
618        
619        for geo in geodesics:
620            color = getattr(geo, 'color', 'red')
621            ax.plot(geo.x, geo.xi, geo.H, color=color, linewidth=3)
622            ax.scatter([geo.x[0]], [geo.xi[0]], [geo.H[0]], 
623                       color=color, s=100, edgecolors='black', linewidths=2)
624        
625        ax.set_xlabel('x')
626        ax.set_ylabel('ξ')
627        ax.set_zlabel('H(x,ξ)')
628        ax.set_title('Hamiltonian Surface\n+ Geodesics', fontweight='bold')
629        ax.view_init(elev=25, azim=45)
630        
631        # 🔧 Adjustments for a consistent panel size
632        ax.set_box_aspect((1, 1, 0.6))   # visual balance between the (x, ξ, H) axes
633        ax.margins(0)                    # remove internal margins
634        ax.set_proj_type('ortho')        # orthographic projection = less distortion
635    
636    def _plot_level_sets(self, fig, X, Xi, H_grid, geodesics, panel):
637        """Panel 2: Energy level sets (symplectic foliation)"""
638        ax = fig.add_subplot(3, 5, panel)
639        levels = np.linspace(np.nanmin(H_grid), np.nanmax(H_grid), 20)
640        contour = ax.contour(X, Xi, H_grid, levels=levels, cmap='viridis')
641        ax.clabel(contour, inline=True, fontsize=8)
642        
643        for geo in geodesics:
644            color = getattr(geo, 'color', 'red')
645            ax.plot(geo.x, geo.xi, color=color, linewidth=2.5)
646        
647        ax.set_xlabel('x')
648        ax.set_ylabel('ξ')
649        ax.set_title('Level Sets H=const\nSymplectic Foliation', fontweight='bold')
650        ax.grid(True, alpha=0.3)
651        ax.set_aspect('auto')     
652        ax.margins(0.05)          
653    
654    
655    def _plot_vector_field(self, fig, X, Xi, grids, geodesics, panel):
656        """Panel 3: Hamiltonian vector field"""
657        ax = fig.add_subplot(3, 5, panel)
658        
659        step = max(1, X.shape[0] // 20)
660        X_sub = X[::step, ::step]
661        Xi_sub = Xi[::step, ::step]
662        vx = grids['dH_dxi'][::step, ::step]
663        vy = -grids['dH_dx'][::step, ::step]
664        
665        magnitude = np.sqrt(vx**2 + vy**2)
666        magnitude[magnitude == 0] = 1
667        
668        ax.quiver(X_sub, Xi_sub, vx/magnitude, vy/magnitude,
669                 magnitude, cmap='plasma', alpha=0.7)
670        
671        for geo in geodesics:
672            color = getattr(geo, 'color', 'cyan')
673            ax.plot(geo.x, geo.xi, color=color, linewidth=3)
674        
675        ax.set_xlabel('x')
676        ax.set_ylabel('ξ')
677        ax.set_title('Hamiltonian Vector Field\n(Infinitesimal generator)', fontweight='bold')
678        ax.grid(True, alpha=0.3)
679    
680    def _plot_group_velocity(self, fig, X, Xi, dH_dxi, geodesics, panel):
681        """Panel 4: Group velocity ∂H/∂ξ"""
682        ax = fig.add_subplot(3, 5, panel)
683        
684        im = ax.contourf(X, Xi, dH_dxi, levels=30, cmap='RdBu_r')
685        plt.colorbar(im, ax=ax, label='∂H/∂ξ')
686        ax.contour(X, Xi, dH_dxi, levels=[0], colors='black', 
687                  linewidths=2, linestyles='--')
688        
689        for geo in geodesics:
690            ax.plot(geo.x, geo.xi, color='yellow', linewidth=2)
691        
692        ax.set_xlabel('x')
693        ax.set_ylabel('ξ')
694        ax.set_title('Group Velocity v_g = ∂H/∂ξ\n(Wave propagation speed)', fontweight='bold')
695        ax.grid(True, alpha=0.3)
696    
697    def _plot_spatial_projection(self, fig, geodesics, panel):
698        """Panel 5: Spatial projection (with caustics)"""
699        ax = fig.add_subplot(3, 5, panel)
700        
701        for geo in geodesics:
702            color = getattr(geo, 'color', 'blue')
703            ax.plot(geo.x, geo.t, color=color, linewidth=2.5)
704            
705            # Mark caustics
706            caust_idx = geo.caustics
707            if len(caust_idx) > 0:
708                ax.scatter(geo.x[caust_idx], geo.t[caust_idx],
709                          color='red', s=150, marker='*', zorder=15,
710                          edgecolors='darkred', linewidths=1.5)
711        
712        ax.set_xlabel('x')
713        ax.set_ylabel('t')
714        ax.set_title('Spatial Projection\n★ = Caustics', fontweight='bold')
715        ax.grid(True, alpha=0.3)
716    
717    def _plot_jacobian(self, fig, geodesics, panel):
718        """Panel 6: Jacobian (focusing measure)"""
719        ax = fig.add_subplot(3, 5, panel)
720        
721        for geo in geodesics:
722            color = getattr(geo, 'color', 'blue')
723            ax.plot(geo.t, geo.J, color=color, linewidth=2.5)
724        
725        ax.axhline(0, color='red', linestyle='--', linewidth=2, alpha=0.7)
726        ax.set_xlabel('t')
727        ax.set_ylabel('J = ∂x/∂ξ₀')
728        ax.set_title('Jacobian (Focusing)\nJ→0: rays converge', fontweight='bold')
729        ax.grid(True, alpha=0.3)
730    
731    def _plot_curvature(self, fig, X, Xi, curvature, geodesics, panel):
732        """Panel 7: Sectional curvature"""
733        ax = fig.add_subplot(3, 5, panel)
734        
735        im = ax.contourf(X, Xi, curvature, levels=30, cmap='seismic')
736        plt.colorbar(im, ax=ax, label='∂²H/∂x∂ξ')
737        
738        for geo in geodesics:
739            ax.plot(geo.x, geo.xi, color='lime', linewidth=2)
740        
741        ax.set_xlabel('x')
742        ax.set_ylabel('ξ')
743        ax.set_title('Sectional Curvature\nRed>0: focusing | Blue<0: defocusing', fontweight='bold')
744        ax.grid(True, alpha=0.3)
745    
746    def _plot_energy_conservation(self, fig, geodesics, panel):
747        """Panel 8: Energy conservation (integration quality)"""
748        ax = fig.add_subplot(3, 5, panel)
749        
750        for geo in geodesics:
751            color = getattr(geo, 'color', 'blue')
752            H_variation = (geo.H - geo.H[0]) / (np.abs(geo.H[0]) + 1e-10)
753            ax.semilogy(geo.t, np.abs(H_variation) + 1e-16,
754                       color=color, linewidth=2.5, label=f'E₀={geo.H[0]:.2f}')
755        
756        ax.set_xlabel('t')
757        ax.set_ylabel('|ΔH/H₀|')
758        ax.set_title('Energy Conservation\n(Numerical quality)', fontweight='bold')
759        ax.legend(fontsize=9)
760        ax.grid(True, alpha=0.3, which='both')
761    
762    def _plot_periodic_orbits(self, fig, X, Xi, H_grid, periodic_orbits, panel):
763        """Panel 9: Periodic orbits in phase space"""
764        ax = fig.add_subplot(3, 5, panel)
765        
766        # Energy level sets
767        energies = np.unique([orb.energy for orb in periodic_orbits])
768        contour = ax.contour(X, Xi, H_grid, levels=energies, 
769                            cmap='viridis', linewidths=1.5, alpha=0.6)
770        
771        # Periodic orbits
772        colors_orb = plt.cm.rainbow(np.linspace(0, 1, len(periodic_orbits)))
773        for idx, orb in enumerate(periodic_orbits):
774            ax.plot(orb.x_cycle, orb.xi_cycle, 
775                   color=colors_orb[idx], linewidth=3, alpha=0.8)
776            ax.scatter([orb.x0], [orb.xi0], color=colors_orb[idx], 
777                      s=100, marker='o', edgecolors='black', linewidths=2, zorder=10)
778        
779        ax.set_xlabel('x')
780        ax.set_ylabel('ξ')
781        ax.set_title('Periodic Orbits\n(Phase space)', fontweight='bold')
782        ax.grid(True, alpha=0.3)
783        ax.set_aspect('equal')
784    
785    def _plot_period_energy(self, fig, periodic_orbits, panel):
786        """Panel 10: Period-Energy relation"""
787        ax = fig.add_subplot(3, 5, panel)
788        
789        E_orb = [orb.energy for orb in periodic_orbits]
790        T_orb = [orb.period for orb in periodic_orbits]
791        S_orb = [orb.action for orb in periodic_orbits]
792        
793        scatter = ax.scatter(E_orb, T_orb, c=S_orb, s=150,
794                           cmap='plasma', edgecolors='black', linewidths=1.5)
795        plt.colorbar(scatter, ax=ax, label='Action S')
796        
797        ax.set_xlabel('Energy E')
798        ax.set_ylabel('Period T')
799        ax.set_title('Period-Energy Diagram\nT(E)', fontweight='bold')
800        ax.grid(True, alpha=0.3)
801    
802    def _plot_ebk_quantization(self, fig, periodic_orbits, hbar, panel):
803        """Panel 11: EBK quantization (Einstein-Brillouin-Keller)"""
804        ax = fig.add_subplot(3, 5, panel)
805        
806        E_orb = [orb.energy for orb in periodic_orbits]
807        S_orb = [orb.action for orb in periodic_orbits]
808        T_orb = [orb.period for orb in periodic_orbits]
809        
810        scatter = ax.scatter(E_orb, S_orb, s=150, c=T_orb, cmap='cool',
811                           edgecolors='black', linewidths=1.5)
812        plt.colorbar(scatter, ax=ax, label='Period T')
813        
814        # EBK quantization rules: S = 2πℏ(n + α)
815        E_max = max(E_orb) if E_orb else 10
816        for n in range(15):
817            S_quant = 2 * np.pi * hbar * (n + 0.25)  # α ≈ 1/4 for 1D
818            if S_quant < max(S_orb) if S_orb else 10:
819                ax.axhline(S_quant, color='red', linestyle='--', alpha=0.3, linewidth=1)
820                ax.text(min(E_orb) if E_orb else 0, S_quant, f'n={n}',
821                       fontsize=8, color='red', va='bottom')
822        
823        ax.set_xlabel('Energy E')
824        ax.set_ylabel('Action S')
825        ax.set_title('EBK Quantization\nS = 2πℏ(n+α)', fontweight='bold')
826        ax.grid(True, alpha=0.3)
827    
828    def _plot_trace_formula(self, fig, spectrum, panel):
829        """Panel 12: Gutzwiller trace formula"""
830        ax = fig.add_subplot(3, 5, panel)
831        
832        # Plot only first part for clarity
833        n_plot = min(500, len(spectrum.trace_t))
834        ax.plot(spectrum.trace_t[:n_plot], np.real(spectrum.trace[:n_plot]),
835               'b-', linewidth=1.5, label='Re[Tr]')
836        ax.plot(spectrum.trace_t[:n_plot], np.imag(spectrum.trace[:n_plot]),
837               'r-', linewidth=1.5, alpha=0.7, label='Im[Tr]')
838        
839        ax.set_xlabel('Time t')
840        ax.set_ylabel('Tr[exp(-iHt/ℏ)]')
        ax.set_title('Gutzwiller Trace Formula\nΣ_γ A_γ exp(iS_γ/ℏ)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_spectrum(self, fig, spectrum, panel):
        """Panel 13: Semiclassical spectrum"""
        ax = fig.add_subplot(3, 5, panel)

        # Only positive energies
        mask = spectrum.energies > 0
        E_positive = spectrum.energies[mask]
        I_positive = spectrum.intensity[mask]

        # Detect peaks
        peaks, properties = find_peaks(I_positive,
                                       height=np.max(I_positive)*0.1,
                                       distance=20)

        ax.plot(E_positive, I_positive, 'b-', linewidth=1.5)
        ax.plot(E_positive[peaks], I_positive[peaks],
                'ro', markersize=10, label='Energy levels')

        # Annotate first levels
        for i, peak in enumerate(peaks[:10]):
            E_level = E_positive[peak]
            ax.text(E_level, I_positive[peak], f'E_{i}',
                    fontsize=9, ha='center', va='bottom')

        ax.set_xlabel('Energy E')
        ax.set_ylabel('Spectral density')
        ax.set_title('Semiclassical Spectrum\n(Fourier transform of trace)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_stability(self, fig, periodic_orbits, panel):
        """Panel 14: Orbit stability (Lyapunov exponents)"""
        ax = fig.add_subplot(3, 5, panel)

        stab = [orb.stability for orb in periodic_orbits]
        E_stab = [orb.energy for orb in periodic_orbits]
        T_stab = [orb.period for orb in periodic_orbits]

        scatter = ax.scatter(E_stab, stab, s=150, c=T_stab, cmap='autumn',
                             edgecolors='black', linewidths=1.5)
        plt.colorbar(scatter, ax=ax, label='Period T')
        ax.axhline(0, color='green', linestyle='--', linewidth=2,
                   label='Marginal stability')

        ax.set_xlabel('Energy E')
        ax.set_ylabel('Lyapunov exponent λ')
        ax.set_title('Orbit Stability\nλ>0: unstable | λ<0: stable', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_level_spacing(self, fig, spectrum, panel):
        """Panel 15: Level spacing distribution (integrability test)"""
        ax = fig.add_subplot(3, 5, panel)

        # Extract energy levels
        mask = spectrum.energies > 0
        E_positive = spectrum.energies[mask]
        I_positive = spectrum.intensity[mask]

        peaks, _ = find_peaks(I_positive, height=np.max(I_positive)*0.1)

        if len(peaks) > 1:
            E_levels = E_positive[peaks]
            spacings = np.diff(E_levels)

            # Normalize spacings
            s_mean = np.mean(spacings)
            s_normalized = spacings / s_mean

            # Histogram
            ax.hist(s_normalized, bins=20, density=True, alpha=0.7,
                    color='blue', edgecolor='black', label='Data')

            # Theoretical distributions
            s = np.linspace(0, np.max(s_normalized), 100)

            # Poisson (integrable systems)
            poisson = np.exp(-s)
            ax.plot(s, poisson, 'g--', linewidth=2, label='Poisson (integrable)')

            # Wigner (chaotic systems)
            wigner = (np.pi * s / 2) * np.exp(-np.pi * s**2 / 4)
            ax.plot(s, wigner, 'r-', linewidth=2, label='Wigner (chaotic)')

            ax.set_xlabel('Normalized spacing s')
            ax.set_ylabel('P(s)')
            ax.set_title('Level Spacing Distribution\nIntegrable vs Chaotic', fontweight='bold')
            ax.legend()
            ax.grid(True, alpha=0.3)

Comprehensive visualization of symbol geometry

Produces 15 panels showing:

  1. Hamiltonian surface (3D)
  2. Energy level sets (phase space foliation)
  3. Hamiltonian vector field
  4. Group velocity ∂H/∂ξ
  5. Spatial projection (caustics)
  6. Jacobian (focusing measure)
  7. Curvature (focusing tendency)
  8. Energy conservation
  9. Periodic orbits (phase space)
  10. Period-energy diagram
  11. EBK quantization
  12. Trace formula
  13. Semiclassical spectrum
  14. Orbit stability
  15. Level spacing distribution
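Panel 14 summarizes each periodic orbit by a Lyapunov exponent. A minimal sanity check for such estimates is to integrate the linearized flow d(δz)/dt = J0·Hess(H)·δz for an assumed 1D inverted oscillator H = ξ²/2 - x²/2. This is a standalone NumPy/SciPy sketch, not the psipy implementation. The origin of that Hamiltonian is a hyperbolic fixed point with exponent 1. For δz(0) = (ε, 0), the finite-time exponent has the closed form ln cosh(2T)/(2T), which tends to 1 as T grows. The names `linearized` and `finite_time_lyapunov` are illustrative.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Assumed example: inverted oscillator H(x, ξ) = ξ²/2 - x²/2, linearized at the origin.
J0 = np.array([[0.0, 1.0],
               [-1.0, 0.0]])        # 1D symplectic matrix for z = (x, ξ)
HESS = np.array([[-1.0, 0.0],
                 [0.0, 1.0]])       # Hessian of H with respect to (x, ξ)


def linearized(t, dz):
    """Tangent dynamics d(δz)/dt = J0 · Hess(H) · δz."""
    return J0 @ HESS @ dz


def finite_time_lyapunov(T, eps=1e-6):
    """log(|δz(T)| / |δz(0)|) / T for an initial perturbation along x."""
    sol = solve_ivp(linearized, (0.0, T), [eps, 0.0],
                    method="DOP853", rtol=1e-10, atol=1e-14)
    return np.log(np.linalg.norm(sol.y[:, -1]) / eps) / T


T = 20.0
lam = finite_time_lyapunov(T)
# Numerical estimate next to the closed form ln cosh(2T) / (2T)
print(lam, np.log(np.cosh(2 * T)) / (2 * T))
```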
SymbolVisualizer(geometry: SymbolGeometry)

Parameters

geometry : SymbolGeometry
    Initialized geometry engine

geo
def visualize_complete( self, x_range: Tuple[float, float], xi_range: Tuple[float, float], geodesics_params: List[Tuple], E_range: Optional[Tuple[float, float]] = None, hbar: float = 1.0, resolution: int = 100) -> Tuple:

Create complete geometric atlas

Parameters

x_range, xi_range : tuple
    Domain limits
geodesics_params : list of tuples
    Each tuple: (x0, xi0, t_max, color)
E_range : tuple, optional
    Energy range for spectral analysis
hbar : float
    Reduced Planck constant
resolution : int
    Grid resolution

Returns

fig, geodesics, periodic_orbits, spectrum

class SpectralAnalysis:
class SpectralAnalysis:
    """
    Additional spectral analysis tools
    """

    @staticmethod
    def weyl_law(energy: float, dimension: int, hbar: float = 1.0) -> float:
        """
        Weyl's law: asymptotic number of states below energy E

        N(E) ~ (1/2πℏ)^d × Vol{H(x,p) ≤ E}

        Parameters
        ----------
        energy : float
            Energy threshold
        dimension : int
            Number of degrees of freedom d (the phase space is 2d-dimensional)
        hbar : float
            Reduced Planck constant

        Returns
        -------
        float
            Approximate number of states below energy E
        """
        # Simplified: assumes Vol{H ≤ E} = E^d with unit prefactor, so only the
        # scaling with E and ℏ is meaningful, not the absolute count
        prefactor = (1 / (2 * np.pi * hbar)) ** dimension
        return prefactor * (energy ** dimension)

    @staticmethod
    def analyze_integrability(spacings: np.ndarray) -> Dict:
        """
        Determine if system is integrable or chaotic via level statistics

        Parameters
        ----------
        spacings : array
            Energy level spacings

        Returns
        -------
        dict
            Statistical measures and classification
        """
        s_mean = np.mean(spacings)
        s_normalized = spacings / s_mean

        # Brody parameter (0: Poisson, 1: Wigner)
        # Fit P(s) = a s^β exp(-b s^(β+1))
        # Simplified: use ratio test

        # <s²>/<s>² ratio
        ratio = np.mean(s_normalized**2) / (np.mean(s_normalized)**2)

        # Poisson: ratio ≈ 2
        # Wigner: ratio ≈ 1.27

        if ratio > 1.7:
            classification = "Integrable (Poisson-like)"
        elif ratio < 1.4:
            classification = "Chaotic (Wigner-like)"
        else:
            classification = "Intermediate"

        return {
            'ratio': ratio,
            'mean_spacing': s_mean,
            'std_spacing': np.std(spacings),
            'classification': classification
        }

    @staticmethod
    def berry_tabor_formula(periodic_orbits: List[PeriodicOrbit],
                            energy: float) -> float:
        """
        Berry-Tabor formula for integrable systems

        Smoothed density of states from periodic orbits (crude estimate:
        each orbit within 0.1 of `energy` contributes T/(2π), with ℏ = 1)

        Parameters
        ----------
        periodic_orbits : list
            Periodic orbits
        energy : float
            Energy at which to evaluate density

        Returns
        -------
        float
            Density of states ρ(E)
        """
        density = 0.0

        for orb in periodic_orbits:
            if abs(orb.energy - energy) < 0.1:
                # Contribution from this orbit
                # ρ(E) ~ T(E) / (2πℏ) for integrable systems (ℏ = 1 here)
                density += orb.period / (2 * np.pi)

        return density

Additional spectral analysis tools

@staticmethod
def weyl_law(energy: float, dimension: int, hbar: float = 1.0) -> float:

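Weyl's law: asymptotic number of states below energy E

N(E) ~ (1/2πℏ)^d × Vol{H(x,p) ≤ E}

where d is the number of degrees of freedom (the phase space has dimension 2d). The implementation uses the simplified scaling Vol{H ≤ E} ≈ E^d.

Parameters

energy : float
    Energy threshold
dimension : int
    Number of degrees of freedom d
hbar : float
    Reduced Planck constant

Returns

float
    Approximate number of states below energy E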
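To see what the formula counts, the sketch below estimates Vol{H ≤ E} on a grid for an assumed 1D harmonic oscillator H = (p² + x²)/2. It is independent of `weyl_law`, which replaces the volume by E^d. For this oscillator the exact area is 2πE. The sketch compares the resulting N(E) with the exact count of levels E_n = ℏ(n + 1/2) below E. The helpers `phase_space_area` and `weyl_count` are hypothetical, and ℏ = 0.1, E = 2 are illustrative values.

```python
import numpy as np


def phase_space_area(H, E, x_max, p_max, n=1201):
    """Grid estimate of the area of {(x, p): H(x, p) <= E} (one degree of freedom)."""
    x = np.linspace(-x_max, x_max, n)
    p = np.linspace(-p_max, p_max, n)
    X, P = np.meshgrid(x, p)
    cell = (x[1] - x[0]) * (p[1] - p[0])
    return np.count_nonzero(H(X, P) <= E) * cell


def weyl_count(H, E, hbar, x_max, p_max):
    """N(E) ≈ Vol{H <= E} / (2πħ)^d with d = 1."""
    return phase_space_area(H, E, x_max, p_max) / (2 * np.pi * hbar)


H_osc = lambda x, p: 0.5 * (p**2 + x**2)          # assumed example Hamiltonian
hbar, E = 0.1, 2.0
n_weyl = weyl_count(H_osc, E, hbar, x_max=2.5, p_max=2.5)
n_exact = np.count_nonzero(hbar * (np.arange(1000) + 0.5) <= E)   # E_n = ħ(n + 1/2)
print(n_weyl, n_exact)
```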

@staticmethod
def analyze_integrability(spacings: numpy.ndarray) -> Dict:

Determine if system is integrable or chaotic via level statistics

Parameters

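spacings : array
    Energy level spacings

Returns

dict
    Keys 'ratio' (<s²>/<s>² of the spacings normalized to unit mean), 'mean_spacing', 'std_spacing' and 'classification' (Poisson-like if ratio > 1.7, Wigner-like if ratio < 1.4, otherwise intermediate)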
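The thresholds 1.7 and 1.4 lie between the two reference values of the statistic <s²>/<s>². It equals 2 for the Poisson distribution P(s) = e^(-s) and 4/π ≈ 1.27 for the Wigner surmise P(s) = (πs/2)e^(-πs²/4). The sketch below checks this on synthetic spacings. Poisson spacings are drawn from an exponential distribution. Wigner spacings use the inverse CDF s = sqrt(-4 ln(1 - u)/π). `spacing_ratio` is a hypothetical stand-in that computes the same statistic as `analyze_integrability`.

```python
import numpy as np


def spacing_ratio(spacings):
    """<s²>/<s>² for spacings normalized to unit mean."""
    s = np.asarray(spacings, dtype=float)
    s = s / s.mean()
    return np.mean(s**2) / np.mean(s)**2


rng = np.random.default_rng(0)
n = 200_000
poisson = rng.exponential(1.0, n)                       # P(s) = exp(-s)
u = rng.random(n)
wigner = np.sqrt(-4.0 * np.log1p(-u) / np.pi)           # inverse CDF of the Wigner surmise

print(spacing_ratio(poisson), 2.0)
print(spacing_ratio(wigner), 4.0 / np.pi)
```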

@staticmethod
def berry_tabor_formula( periodic_orbits: List[src.geometry_1d.PeriodicOrbit], energy: float) -> float:

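Berry-Tabor formula for integrable systems

Smoothed density of states from periodic orbits. This is a crude estimate: every orbit whose energy lies within 0.1 of E contributes T/(2π), i.e. the one-dimensional mean density T(E)/(2πℏ) with ℏ = 1.

Parameters

periodic_orbits : list
    Periodic orbits
energy : float
    Energy at which to evaluate density

Returns

float
    Density of states ρ(E)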
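For one degree of freedom the smooth density is ρ(E) = T(E)/(2πℏ), where T(E) = dS/dE is the period and S(E) = ∮ p dx is the action. The standalone sketch below (not psipy code) computes S(E) by quadrature between the turning points and then differentiates numerically. It assumes a particle of unit mass in a potential V whose minimum is at x = 0. For the harmonic oscillator V = x²/2 the period is 2π, so ρ = 1/ℏ. For the quartic oscillator V = x⁴/4 the period, and hence ρ, decreases with energy. The helpers `action` and `mean_density` are hypothetical.

```python
import numpy as np
from scipy.integrate import quad
from scipy.optimize import brentq


def action(E, V, x_lo, x_hi):
    """S(E) = ∮ p dx = 2 ∫ sqrt(2 (E - V(x))) dx between the turning points (m = 1)."""
    f = lambda x: V(x) - E
    a = brentq(f, x_lo, 0.0)   # left turning point (assumes the minimum of V is at x = 0)
    b = brentq(f, 0.0, x_hi)   # right turning point
    val, _ = quad(lambda x: np.sqrt(max(2.0 * (E - V(x)), 0.0)), a, b)
    return 2.0 * val


def mean_density(E, V, hbar, x_lo=-50.0, x_hi=50.0, dE=1e-3):
    """ρ(E) = T(E) / (2πħ), with T(E) = dS/dE from a central difference."""
    T = (action(E + dE, V, x_lo, x_hi) - action(E - dE, V, x_lo, x_hi)) / (2 * dE)
    return T / (2 * np.pi * hbar)


V_harm = lambda x: 0.5 * x**2
V_quartic = lambda x: 0.25 * x**4
print(mean_density(1.0, V_harm, 1.0))       # compare with 1/ħ for ω = 1
print(mean_density(1.0, V_quartic, 1.0), mean_density(2.0, V_quartic, 1.0))
```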

class SymbolGeometry2D:
class SymbolGeometry2D:
    """
    Full geometric and semi-classical analysis of a 2D symbol
    H(x, y, ξ, η) with 4D phase space and rigorous caustic treatment
    """
    def __init__(self, symbol: sp.Expr,
                 x_sym: sp.Symbol, y_sym: sp.Symbol,
                 xi_sym: sp.Symbol, eta_sym: sp.Symbol,
                 hbar: float = 1.0):
        """
        Initialization with complete derivative computation for Jacobian evolution

        Parameters
        ----------
        symbol : sympy expression
            Hamiltonian H(x, y, ξ, η)
        x_sym, y_sym : sympy symbols
            Position coordinates
        xi_sym, eta_sym : sympy symbols
            Momentum coordinates
        hbar : float
            Reduced Planck constant (for quantum aspects)
        """
        self.H_sym = symbol
        self.x_sym = x_sym
        self.y_sym = y_sym
        self.xi_sym = xi_sym
        self.eta_sym = eta_sym
        self.hbar = hbar

        print(f"Initializing 2D geometry engine for H = {self.H_sym} with ℏ = {self.hbar}")
        # --- First derivatives (Hamiltonian vector field) ---
        dH_x = sp.diff(self.H_sym, self.x_sym)
        self.dH_dx_sym = _sanitize(dH_x)
        dH_y = sp.diff(self.H_sym, self.y_sym)
        self.dH_dy_sym = _sanitize(dH_y)
        dH_xi = sp.diff(self.H_sym, self.xi_sym)
        self.dH_dxi_sym = _sanitize(dH_xi)
        dH_eta = sp.diff(self.H_sym, self.eta_sym)
        self.dH_deta_sym = _sanitize(dH_eta)

        # --- Second derivatives for variational equations ---
        d2H_x2 = sp.diff(self.dH_dx_sym, self.x_sym)
        self.d2H_dx2_sym = _sanitize(d2H_x2)
        d2H_y2 = sp.diff(self.dH_dy_sym, self.y_sym)
        self.d2H_dy2_sym = _sanitize(d2H_y2)
        d2H_xi2 = sp.diff(self.dH_dxi_sym, self.xi_sym)
        self.d2H_dxi2_sym = _sanitize(d2H_xi2)
        d2H_eta2 = sp.diff(self.dH_deta_sym, self.eta_sym)
        self.d2H_deta2_sym = _sanitize(d2H_eta2)
        d2H_xy = sp.diff(self.dH_dx_sym, self.y_sym)
        self.d2H_dxdy_sym = _sanitize(d2H_xy)
        d2H_xxi = sp.diff(self.dH_dx_sym, self.xi_sym)
        self.d2H_dxdxi_sym = _sanitize(d2H_xxi)
        d2H_xeta = sp.diff(self.dH_dx_sym, self.eta_sym)
        self.d2H_dxdeta_sym = _sanitize(d2H_xeta)
        d2H_yxi = sp.diff(self.dH_dy_sym, self.xi_sym)
        self.d2H_dydxi_sym = _sanitize(d2H_yxi)
        d2H_yeta = sp.diff(self.dH_dy_sym, self.eta_sym)
        self.d2H_dyeta_sym = _sanitize(d2H_yeta)
        d2H_xieta = sp.diff(self.dH_dxi_sym, self.eta_sym)
        self.d2H_dxideta_sym = _sanitize(d2H_xieta)
        # --- Hessian for variational equations ---
        self.Hessian = sp.Matrix([
            [self.d2H_dx2_sym, self.d2H_dxdy_sym, self.d2H_dxdxi_sym, self.d2H_dxdeta_sym],
            [self.d2H_dxdy_sym, self.d2H_dy2_sym, self.d2H_dydxi_sym, self.d2H_dyeta_sym],
            [self.d2H_dxdxi_sym, self.d2H_dydxi_sym, self.d2H_dxi2_sym, self.d2H_dxideta_sym],
            [self.d2H_dxdeta_sym, self.d2H_dyeta_sym, self.d2H_dxideta_sym, self.d2H_deta2_sym]
        ])

        # --- Convert to numerical functions ---
        self._lambdify_functions()

    def _safe_lambdify(self, args: tuple, expr: sp.Expr) -> Callable:
        """Safe conversion of sympy expressions to numerical functions"""
        # Constant expressions (including rationals such as 1/2 or constants such as pi)
        # are broadcast to the input's shape so array evaluation returns an array
        if isinstance(expr, (int, float)) or (isinstance(expr, sp.Expr) and expr.is_number):
            const_val = float(expr)
            return lambda x, y, xi, eta: np.full_like(np.asarray(x, dtype=float), const_val)
        try:
            return sp.lambdify(args, expr, modules=['numpy', 'scipy'])
        except Exception as e:
            print(f"Warning: lambdify failed for {expr}. Error: {e}")
            return lambda x, y, xi, eta: np.full_like(np.asarray(x, dtype=float), np.nan)

    def _lambdify_functions(self):
        """Convert all symbolic expressions to numerical functions"""
        args = (self.x_sym, self.y_sym, self.xi_sym, self.eta_sym)
        self.H_num = self._safe_lambdify(args, self.H_sym)
        self.dH_dx_num = self._safe_lambdify(args, self.dH_dx_sym)
        self.dH_dy_num = self._safe_lambdify(args, self.dH_dy_sym)
        self.dH_dxi_num = self._safe_lambdify(args, self.dH_dxi_sym)
        self.dH_deta_num = self._safe_lambdify(args, self.dH_deta_sym)
        # Hessian functions
        self.second_derivs_funcs = []
        for i in range(4):
            row_funcs = []
            for j in range(4):
                row_funcs.append(self._safe_lambdify(args, self.Hessian[i, j]))
            self.second_derivs_funcs.append(row_funcs)

    def _hamiltonian_system_augmented(self, t: float, z: np.ndarray) -> np.ndarray:
        """
        Augmented Hamiltonian system with variational equations for Jacobian evolution
        State vector z = [x, y, xi, eta, J11, J12, ..., J44] (20 dimensions)
        """
        # Extract position and momentum
        x, y, xi, eta = z[0:4]
        # Extract Jacobian matrix J = ∂z(t)/∂z(0) (4x4)
        J = z[4:].reshape((4, 4))
        try:
            # Hamilton's equations
            dx = float(self.dH_dxi_num(x, y, xi, eta))
            dy = float(self.dH_deta_num(x, y, xi, eta))
            dxi = float(-self.dH_dx_num(x, y, xi, eta))
            deta = float(-self.dH_dy_num(x, y, xi, eta))
            # Evaluate numerical Hessian
            Hessian_num = np.zeros((4, 4))
            for i in range(4):
                for j in range(4):
                    Hessian_num[i, j] = float(self.second_derivs_funcs[i][j](x, y, xi, eta))
            # Symplectic matrix J0
            J0 = np.array([
                [0, 0, 1, 0],
                [0, 0, 0, 1],
                [-1, 0, 0, 0],
                [0, -1, 0, 0]
            ])
            # Variational equations: dJ/dt = J0 @ Hessian @ J
            # (the linearized vector field acts on the left of the Jacobian)
            dJ_dt = J0 @ Hessian_num @ J
            # Build derivative vector
            dz = np.zeros(20)
            dz[0:4] = [dx, dy, dxi, deta]
            dz[4:] = dJ_dt.flatten()
            return dz
        except Exception as e:
            print(f"Integration error at t={t}, z={z}: {e}")
            return np.zeros(20)

    def compute_geodesic(self, x0: float, y0: float,
                         xi0: float, eta0: float,
                         t_max: float, n_points: int = 500) -> Geodesic2D:
        """
        Compute a geodesic with full Jacobian evolution for caustic detection

        Parameters
        ----------
        x0, y0 : float
            Initial position
        xi0, eta0 : float
            Initial momentum
        t_max : float
            Final time
        n_points : int
            Number of sampling points

        Returns
        -------
        Geodesic2D
            Structure containing trajectory and caustic analysis
        """
        # Initial condition: position, momentum + identity Jacobian
        z0 = np.zeros(20)
        z0[0:4] = [x0, y0, xi0, eta0]
        z0[4:] = np.eye(4).flatten()
        t_eval = np.linspace(0, t_max, n_points)
        sol = solve_ivp(
            self._hamiltonian_system_augmented,
            [0, t_max], z0, t_eval=t_eval,
            method='DOP853', rtol=1e-9, atol=1e-12
        )
        if not sol.success:
            print(f"Warning: Integration failed for ({x0}, {y0}, {xi0}, {eta0})")
        # Number of samples actually returned (fewer than n_points if the solver stopped early)
        n_out = sol.t.size
        # Extract trajectory data
        x_traj = sol.y[0]
        y_traj = sol.y[1]
        xi_traj = sol.y[2]
        eta_traj = sol.y[3]
        # Evaluate energy
        H_vals = self.H_num(x_traj, y_traj, xi_traj, eta_traj)
        # Reshape the flattened Jacobians into an (n_out, 4, 4) array
        J_mats = sol.y[4:].T.reshape((n_out, 4, 4))
        # Submatrix for caustic detection: ∂(x,y)/∂(ξ₀,η₀)
        caustic_matrix = J_mats[:, 0:2, 2:4]
        # Determinant for caustic detection (batched over time samples)
        det_caustic = np.linalg.det(caustic_matrix)
        # Detect caustic indices (sign change). The block vanishes at t = 0, so the
        # trivial transition away from det = 0 at index 0 is not counted as a caustic.
        caustic_indices = np.where(np.diff(np.sign(det_caustic)))[0]
        caustic_indices = caustic_indices[caustic_indices > 0]
        return Geodesic2D(
            t=sol.t,
            x=x_traj,
            y=y_traj,
            xi=xi_traj,
            eta=eta_traj,
            H=H_vals,
            J_full=J_mats,
            det_caustic=det_caustic,
            caustic_indices=caustic_indices
        )

    def find_periodic_orbits_2d(self, energy: float,
                                x_range: Tuple[float, float],
                                y_range: Tuple[float, float],
                                xi_range: Tuple[float, float],
                                eta_range: Tuple[float, float],
                                n_attempts: int = 30) -> List[PeriodicOrbit2D]:
        """
        Search for periodic orbits with Maslov index computation

        Initial positions are sampled on a grid over x_range × y_range; initial
        momenta are sampled in 8 directions with radii between 0.5 and 3, so
        xi_range and eta_range are currently not used by the search.
        """
        orbits = []
        # Sample configuration space
        n_samples = int(np.sqrt(n_attempts))
        x_samples = np.linspace(x_range[0], x_range[1], n_samples)
        y_samples = np.linspace(y_range[0], y_range[1], n_samples)
        for x0 in x_samples:
            for y0 in y_samples:
                # Test 8 distinct momentum directions (endpoint=False avoids sampling 0 and 2π twice)
                angles = np.linspace(0, 2*np.pi, 8, endpoint=False)
                for angle in angles:
                    for r in np.linspace(0.5, 3, 3):
                        xi0_guess = r * np.cos(angle)
                        eta0_guess = r * np.sin(angle)
                        try:
                            # Energy check
                            E_test = self.H_num(x0, y0, xi0_guess, eta0_guess)
                            if abs(E_test - energy) > 0.5:
                                continue
                            # Compute geodesic
                            geo = self.compute_geodesic(x0, y0, xi0_guess, eta0_guess, 15, 1500)
                            # Search for return points
                            distances = np.sqrt((geo.x - x0)**2 + (geo.y - y0)**2 +
                                                (geo.xi - xi0_guess)**2 + (geo.eta - eta0_guess)**2)
                            minima = []
                            for i in range(10, len(distances)-10):
                                if (distances[i] < distances[i-1] and
                                    distances[i] < distances[i+1] and
                                    distances[i] < 0.05):
                                    minima.append(i)
                            if minima:
                                idx = minima[0]
                                period = geo.t[idx]
                                if period > 0.2 and distances[idx] < 0.05:
                                    # Compute action ∮ (ξ dx + η dy)
                                    x_cyc = geo.x[:idx+1]
                                    y_cyc = geo.y[:idx+1]
                                    xi_cyc = geo.xi[:idx+1]
                                    eta_cyc = geo.eta[:idx+1]
                                    t_cyc = geo.t[:idx+1]
                                    dx_dt = np.gradient(x_cyc, t_cyc)
                                    dy_dt = np.gradient(y_cyc, t_cyc)
                                    action = np.trapz(xi_cyc * dx_dt + eta_cyc * dy_dt, t_cyc)
                                    # Maslov index (number of caustic crossings before the return)
                                    maslov_index = len([i for i in geo.caustic_indices if i < idx])
                                    # Stability (largest Lyapunov exponent over one period)
                                    stab1 = self._compute_stability_2d(x0, y0, xi0_guess, eta0_guess, period)
                                    orbits.append(PeriodicOrbit2D(
                                        x0=x0, y0=y0,
                                        xi0=xi0_guess, eta0=eta0_guess,
                                        period=period,
                                        action=action,
                                        energy=energy,
                                        stability_1=stab1,
                                        stability_2=0.0,
                                        x_cycle=x_cyc,
                                        y_cycle=y_cyc,
                                        xi_cycle=xi_cyc,
                                        eta_cycle=eta_cyc,
                                        t_cycle=t_cyc,
                                        maslov_index=maslov_index
                                    ))
                        except Exception:
                            # Skip initial conditions where evaluation or integration fails
                            continue
        return self._remove_duplicate_orbits_2d(orbits)

    def _compute_stability_2d(self, x0, y0, xi0, eta0, T):
        """
        Estimate the largest Lyapunov exponent over one period T

        The perturbation δz follows the full linearized flow
        d(δz)/dt = J0 @ Hessian @ δz along the orbit.
        """
        J0 = np.array([
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [-1, 0, 0, 0],
            [0, -1, 0, 0]
        ], dtype=float)

        def linearized(t, z):
            x, y, xi, eta = z[0:4]
            dz = np.asarray(z[4:8])
            try:
                # Hamilton's equations for the reference orbit
                v = [float(self.dH_dxi_num(x, y, xi, eta)),
                     float(self.dH_deta_num(x, y, xi, eta)),
                     float(-self.dH_dx_num(x, y, xi, eta)),
                     float(-self.dH_dy_num(x, y, xi, eta))]
                # Full 4x4 Hessian at the current point
                Hessian_num = np.array([[float(self.second_derivs_funcs[i][j](x, y, xi, eta))
                                         for j in range(4)] for i in range(4)])
                return np.concatenate([v, J0 @ Hessian_num @ dz])
            except Exception:
                return np.zeros(8)

        eps = 1e-6
        z0 = [x0, y0, xi0, eta0, eps, 0.0, 0.0, 0.0]
        # atol well below eps so the small perturbation is resolved accurately
        sol = solve_ivp(linearized, [0, T], z0, method='DOP853', rtol=1e-10, atol=1e-14)
        if sol.success and sol.y.shape[1] > 0:
            # Growth of the full phase-space perturbation after one period
            pert = np.linalg.norm(sol.y[4:8, -1])
            return np.log(pert / eps) / T
        return np.nan

    def _remove_duplicate_orbits_2d(self, orbits):
        """Remove duplicate periodic orbits"""
        unique = []
        for orb in orbits:
            is_dup = False
            for u_orb in unique:
                if (abs(orb.period - u_orb.period) < 0.2 and
                    abs(orb.action - u_orb.action) < 0.2):
                    is_dup = True
                    break
            if not is_dup:
                unique.append(orb)
        return unique

    def detect_caustic_structures(self, geodesics: List[Geodesic2D],
                                  t_fixed: float) -> List[CausticStructure]:
        """
        Advanced caustic structure detection with classification
        """
        caustic_points = []
        for geo in geodesics:
            # Find closest time to t_fixed
            idx = np.argmin(np.abs(geo.t - t_fixed))
            # Check if near a caustic
            if abs(geo.det_caustic[idx]) < 0.1:
                # Classify caustic type
                caustic_type = self._classify_caustic(geo, idx)
                # Compute singularity strength
                strength = 1.0 / (abs(geo.det_caustic[idx]) + 0.01)
                caustic_points.append({
                    'x': geo.x[idx],
                    'y': geo.y[idx],
                    'energy': geo.energy,
                    'type': caustic_type,
                    'strength': strength
                })
        if len(caustic_points) < 3:
            return []
        # Cluster points into caustic structures
        caustic_structures = self._cluster_caustic_points(caustic_points, t_fixed)
        return caustic_structures

    def _classify_caustic(self, geo: Geodesic2D, idx: int) -> str:
        """
        Caustic classification according to catastrophe theory
        """
        # Compute curvature near caustic point
        window = 10
        start = max(0, idx - window)
        end = min(len(geo.t), idx + window + 1)
        if end - start < 5:
            return 'fold'
        # Curvature approximation
        x_window = geo.x[start:end]
        y_window = geo.y[start:end]
        dx = np.gradient(x_window)
        dy = np.gradient(y_window)
        ddx = np.gradient(dx)
        ddy = np.gradient(dy)
        with np.errstate(divide='ignore', invalid='ignore'):
            curvature = np.abs(dx * ddy - dy * ddx) / (dx**2 + dy**2)**1.5
        curvature = np.nan_to_num(curvature, nan=0.0, posinf=0.0, neginf=0.0)
        # Detect cusp points (high curvature)
        if np.max(curvature) > 2.0 * np.mean(curvature):
            return 'cusp'
        return 'fold'

    def _cluster_caustic_points(self, points: List[dict], t_fixed: float) -> List[CausticStructure]:
        """Group caustic points into coherent structures"""
        if not points:
            return []
        # Simple proximity-based clustering
        clusters = []
        visited = set()
        for i, point in enumerate(points):
            if i in visited:
                continue
            # New cluster
            cluster = [point]
            visited.add(i)
            # Find nearby points
            for j, other in enumerate(points):
                if j in visited:
                    continue
                dist = np.sqrt((point['x'] - other['x'])**2 + (point['y'] - other['y'])**2)
                if dist < 0.5:  # Distance threshold
                    cluster.append(other)
                    visited.add(j)
            # Create caustic structure
            xs = np.array([p['x'] for p in cluster])
            ys = np.array([p['y'] for p in cluster])
            types = [p['type'] for p in cluster]
            strengths = [p['strength'] for p in cluster]
            # Majority type
            type_counts = {}
            for t in types:
                type_counts[t] = type_counts.get(t, 0) + 1
            dominant_type = max(type_counts.items(), key=lambda x: x[1])[0]
            # Maslov index (approximation)
            maslov_index = 1 if dominant_type == 'fold' else 2
            clusters.append(CausticStructure(
                x=xs,
                y=ys,
                t=t_fixed,
                energy=cluster[0]['energy'],
                type=dominant_type,
                maslov_index=maslov_index,
                strength=np.mean(strengths)
            ))
        return clusters

554    def compute_phase_space_volume(self, E_max: float, x_range: tuple, y_range: tuple,
555                                 xi_range: tuple, eta_range: tuple, 
556                                 n_samples: int = 200000) -> float:
557        """Monte Carlo estimation of phase space volume for H ≤ E_max"""
558        # Generate random samples
559        x_samples = np.random.uniform(x_range[0], x_range[1], n_samples)
560        y_samples = np.random.uniform(y_range[0], y_range[1], n_samples)
561        xi_samples = np.random.uniform(xi_range[0], xi_range[1], n_samples)
562        eta_samples = np.random.uniform(eta_range[0], eta_range[1], n_samples)
563        # Evaluate Hamiltonian
564        H_vals = self.H_num(x_samples, y_samples, xi_samples, eta_samples)
565        # Count points where H ≤ E_max
566        volume_ratio = np.mean(H_vals <= E_max)
567        # Total phase space volume
568        total_volume = ((x_range[1]-x_range[0]) * (y_range[1]-y_range[0]) * 
569                       (xi_range[1]-xi_range[0]) * (eta_range[1]-eta_range[0]))
570        return volume_ratio * total_volume
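The `_classify_caustic` heuristic at the top of this listing labels a window of caustic points a cusp when the discrete curvature κ = |x′y″ − y′x″| / (x′² + y′²)^{3/2} has a spike larger than twice its mean, and a fold otherwise. The standalone sketch below reproduces that rule on two synthetic curves, so its behaviour can be checked in isolation. The helper names `discrete_curvature` and `classify_by_curvature` are hypothetical. In the method itself, the window of points comes from a detected caustic rather than from analytic curves.

```python
import numpy as np


def discrete_curvature(x, y):
    """Curvature of a sampled plane curve, using finite differences in the sample index."""
    dx, dy = np.gradient(x), np.gradient(y)
    ddx, ddy = np.gradient(dx), np.gradient(dy)
    with np.errstate(divide='ignore', invalid='ignore'):
        kappa = np.abs(dx * ddy - dy * ddx) / (dx**2 + dy**2) ** 1.5
    # Degenerate samples (zero speed) give NaN/inf; zero them as the listing does
    return np.nan_to_num(kappa, nan=0.0, posinf=0.0, neginf=0.0)


def classify_by_curvature(x, y):
    """Hypothetical helper mirroring the listing's rule: a spike above 2x the mean means 'cusp'."""
    kappa = discrete_curvature(x, y)
    return 'cusp' if np.max(kappa) > 2.0 * np.mean(kappa) else 'fold'


# Smooth arc of radius 2: curvature should be 1/2 away from the end points
theta = np.linspace(0.0, np.pi, 200)
arc_x, arc_y = 2.0 * np.cos(theta), 2.0 * np.sin(theta)

# Semicubical parabola (t², t³): the standard cusp, curvature diverges at t = 0
t = np.linspace(-1.0, 1.0, 201)
cusp_x, cusp_y = t**2, t**3

print(discrete_curvature(arc_x, arc_y)[2:-2].mean())
print(classify_by_curvature(arc_x, arc_y), classify_by_curvature(cusp_x, cusp_y))
```

Because curvature does not depend on how a curve is parametrized, differentiating with respect to the sample index (no spacing argument to `np.gradient`) is enough for uniformly sampled windows.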

Geometric and semiclassical analysis of a 2D symbol H(x, y, ξ, η) on its 4D phase space, with caustic detection based on the evolved Jacobian

SymbolGeometry2D(symbol: sympy.core.expr.Expr, x_sym: sympy.core.symbol.Symbol, y_sym: sympy.core.symbol.Symbol, xi_sym: sympy.core.symbol.Symbol, eta_sym: sympy.core.symbol.Symbol, hbar: float = 1.0)
145    def __init__(self, symbol: sp.Expr, 
146                 x_sym: sp.Symbol, y_sym: sp.Symbol,
147                 xi_sym: sp.Symbol, eta_sym: sp.Symbol,
148                 hbar: float = 1.0):
149        """
150        Initialization with complete derivative computation for Jacobian evolution
151        Parameters
152        ----------
153        symbol : sympy expression
154            Hamiltonian H(x, y, ξ, η)
155        x_sym, y_sym : sympy symbols
156            Position coordinates
157        xi_sym, eta_sym : sympy symbols
158            Momentum coordinates
159        hbar : float
160            Reduced Planck constant (for quantum aspects)
161        """
162        self.H_sym = symbol
163        self.x_sym = x_sym
164        self.y_sym = y_sym
165        self.xi_sym = xi_sym
166        self.eta_sym = eta_sym
167        self.hbar = hbar
168            
169        print(f"Initializing 2D geometry engine for H = {self.H_sym} with ℏ = {self.hbar}")
170        # --- First derivatives (Hamiltonian vector field) ---
171        dH_x = sp.diff(self.H_sym, self.x_sym)
172        self.dH_dx_sym = _sanitize(dH_x)
173        dH_y = sp.diff(self.H_sym, self.y_sym)
174        self.dH_dy_sym = _sanitize(dH_y)
175        dH_xi = sp.diff(self.H_sym, self.xi_sym)
176        self.dH_dxi_sym = _sanitize(dH_xi)
177        dH_eta = sp.diff(self.H_sym, self.eta_sym)
178        self.dH_deta_sym = _sanitize(dH_eta)
179
180        # --- Second derivatives for variational equations ---
181        d2H_x2 = sp.diff(self.dH_dx_sym, self.x_sym)
182        self.d2H_dx2_sym = _sanitize(d2H_x2)
183        d2H_y2 = sp.diff(self.dH_dy_sym, self.y_sym)
184        self.d2H_dy2_sym = _sanitize(d2H_y2)
185        d2H_xi2 = sp.diff(self.dH_dxi_sym, self.xi_sym)
186        self.d2H_dxi2_sym = _sanitize(d2H_xi2)
187        d2H_eta2 = sp.diff(self.dH_deta_sym, self.eta_sym)
188        self.d2H_deta2_sym = _sanitize(d2H_eta2)
189        d2H_xy = sp.diff(self.dH_dx_sym, self.y_sym)
190        self.d2H_dxdy_sym = _sanitize(d2H_xy)
191        d2H_xxi = sp.diff(self.dH_dx_sym, self.xi_sym)
192        self.d2H_dxdxi_sym = _sanitize(d2H_xxi)
193        d2H_xeta = sp.diff(self.dH_dx_sym, self.eta_sym)
194        self.d2H_dxdeta_sym = _sanitize(d2H_xeta)
195        d2H_yxi = sp.diff(self.dH_dy_sym, self.xi_sym)
196        self.d2H_dydxi_sym = _sanitize(d2H_yxi)
197        d2H_yeta = sp.diff(self.dH_dy_sym, self.eta_sym)
198        self.d2H_dyeta_sym = _sanitize(d2H_yeta)
199        d2H_xieta = sp.diff(self.dH_dxi_sym, self.eta_sym)
200        self.d2H_dxideta_sym = _sanitize(d2H_xieta)
201        # --- Hessian for variational equations ---
202        self.Hessian = sp.Matrix([
203            [self.d2H_dx2_sym, self.d2H_dxdy_sym, self.d2H_dxdxi_sym, self.d2H_dxdeta_sym],
204            [self.d2H_dxdy_sym, self.d2H_dy2_sym, self.d2H_dydxi_sym, self.d2H_dyeta_sym],
205            [self.d2H_dxdxi_sym, self.d2H_dydxi_sym, self.d2H_dxi2_sym, self.d2H_dxideta_sym],
206            [self.d2H_dxdeta_sym, self.d2H_dyeta_sym, self.d2H_dxideta_sym, self.d2H_deta2_sym]
207        ])
208
209        # --- Convert to numerical functions ---
210        self._lambdify_functions()

Initialization with complete derivative computation for Jacobian evolution

Parameters

symbol : sympy expression
    Hamiltonian H(x, y, ξ, η)
x_sym, y_sym : sympy symbols
    Position coordinates
xi_sym, eta_sym : sympy symbols
    Momentum coordinates
hbar : float
    Reduced Planck constant (for quantum aspects)

H_sym
x_sym
y_sym
xi_sym
eta_sym
hbar
dH_dx_sym
dH_dy_sym
dH_dxi_sym
dH_deta_sym
d2H_dx2_sym
d2H_dy2_sym
d2H_dxi2_sym
d2H_deta2_sym
d2H_dxdy_sym
d2H_dxdxi_sym
d2H_dxdeta_sym
d2H_dydxi_sym
d2H_dyeta_sym
d2H_dxideta_sym
Hessian
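The ten second derivatives computed in the constructor fill the symmetric 4×4 Hessian of H with respect to z = (x, y, ξ, η). For a polynomial Hamiltonian, the same matrix can be cross-checked with `sympy.hessian`. The sketch below does this for a Hénon–Heiles-type Hamiltonian chosen purely for illustration, and it omits the `_sanitize` step applied in the listing.

It also checks a property of the textbook variational equation dJ/dt = Ω·Hess(H)·J, where Ω = [[0, I], [−I, 0]]. The generator Ω·Hess(H) is a Hamiltonian matrix, so J stays symplectic. Because `_hamiltonian_system_augmented` is not shown in this section, treat that equation as the standard form rather than a description of its code.

```python
import numpy as np
import sympy as sp

x, y, xi, eta = sp.symbols('x y xi eta', real=True)
z = [x, y, xi, eta]

# Illustrative Hénon–Heiles-type Hamiltonian (an assumption, not from the listing)
H = (xi**2 + eta**2) / 2 + (x**2 + y**2) / 2 + x**2 * y - y**3 / 3

# Symmetric 4x4 Hessian in the order (x, y, ξ, η), the same layout as self.Hessian
hess = sp.hessian(H, z)
assert hess == hess.T

hess_num = sp.lambdify(z, hess, 'numpy')

# Symplectic matrix Ω for z = (x, y, ξ, η): dz/dt = Ω ∇H and dJ/dt = Ω Hess(H) J
omega = np.block([[np.zeros((2, 2)), np.eye(2)],
                  [-np.eye(2), np.zeros((2, 2))]])

S = np.asarray(hess_num(0.1, 0.2, 0.3, 0.4), dtype=float)
A = omega @ S  # generator of the linearized flow at this phase-space point

# A is a Hamiltonian matrix (A^T Ω + Ω A = 0), so the evolved J stays symplectic
print(S)
print(np.allclose(A.T @ omega + omega @ A, 0.0))
```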
def compute_geodesic(self, x0: float, y0: float, xi0: float, eta0: float, t_max: float, n_points: int = 500) -> src.geometry_2d.Geodesic2D:
277    def compute_geodesic(self, x0: float, y0: float, 
278                        xi0: float, eta0: float,
279                        t_max: float, n_points: int = 500) -> Geodesic2D:
280        """
281        Compute a geodesic with full Jacobian evolution for caustic detection
282        Parameters
283        ----------
284        x0, y0 : float
285            Initial position
286        xi0, eta0 : float
287            Initial momentum
288        t_max : float
289            Final time
290        n_points : int
291            Number of sampling points
292        Returns
293        -------
294        Geodesic2D
295            Structure containing trajectory and caustic analysis
296        """
297        # Initial condition: position, momentum + identity Jacobian
298        z0 = np.zeros(20)
299        z0[0:4] = [x0, y0, xi0, eta0]
300        z0[4:] = np.eye(4).flatten()
301        t_eval = np.linspace(0, t_max, n_points)
302        sol = solve_ivp(
303            self._hamiltonian_system_augmented,
304            [0, t_max], z0, t_eval=t_eval,
305            method='DOP853', rtol=1e-9, atol=1e-12
306        )
307        if not sol.success:
308            print(f"Warning: Integration failed for ({x0}, {y0}, {xi0}, {eta0})")
309        # Extract trajectory data
310        x_traj = sol.y[0]
311        y_traj = sol.y[1]
312        xi_traj = sol.y[2]
313        eta_traj = sol.y[3]
314        # Evaluate energy
315        H_vals = self.H_num(x_traj, y_traj, xi_traj, eta_traj)
316        # Extract Jacobian matrices (one per returned sample; fewer than
317        # n_points if the integrator stopped early)
318        J_mats = sol.y[4:].T.reshape(-1, 4, 4)
319        # Submatrix for caustic detection: ∂(x,y)/∂(ξ₀,η₀)
320        caustic_matrix = J_mats[:, 0:2, 2:4]
321        # Determinant for caustic detection (vectorized over samples)
322        det_caustic = np.linalg.det(caustic_matrix)
323        # Detect caustic indices (sign change); skip index 0, since the
324        # determinant vanishes at t = 0 for rays leaving a single point
325        caustic_indices = np.where(np.diff(np.sign(det_caustic)))[0]
326        caustic_indices = caustic_indices[caustic_indices > 0]
327
328        return Geodesic2D(
329            t=sol.t,
330            x=x_traj,
331            y=y_traj,
332            xi=xi_traj,
333            eta=eta_traj,
334            H=H_vals,
335            J_full=J_mats,
336            det_caustic=det_caustic,
337            caustic_indices=caustic_indices
338        )

Compute a geodesic with full Jacobian evolution for caustic detection

Parameters

x0, y0 : float
    Initial position
xi0, eta0 : float
    Initial momentum
t_max : float
    Final time
n_points : int
    Number of sampling points

Returns

Geodesic2D
    Structure containing trajectory and caustic analysis
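compute_geodesic flags a caustic wherever det ∂(x, y)/∂(ξ₀, η₀), read from the evolved Jacobian, changes sign. The self-contained sketch below applies the same test to an anisotropic oscillator with ω_x = 1 and ω_y = 2, where the determinant is known in closed form: sin²(t)·cos(t).

The right-hand side `augmented_rhs` is an assumed textbook form of the augmented system, because `_hamiltonian_system_augmented` is not shown here. The initial condition and time grid are also arbitrary illustrative choices.

```python
import numpy as np
import sympy as sp
from scipy.integrate import solve_ivp

x, y, xi, eta = sp.symbols('x y xi eta', real=True)
z_syms = (x, y, xi, eta)

# Anisotropic oscillator with ω_x = 1, ω_y = 2 (illustrative choice)
H = (xi**2 + eta**2) / 2 + (x**2 + 4 * y**2) / 2

grad_H = sp.lambdify(z_syms, [sp.diff(H, s) for s in z_syms], 'numpy')
hess_H = sp.lambdify(z_syms, sp.hessian(H, list(z_syms)), 'numpy')

# dz/dt = Ω ∇H and dJ/dt = Ω Hess(H) J, with z = (x, y, ξ, η)
OMEGA = np.block([[np.zeros((2, 2)), np.eye(2)],
                  [-np.eye(2), np.zeros((2, 2))]])


def augmented_rhs(t, w):
    """State (4 entries) followed by the row-major flattened 4x4 Jacobian."""
    z, J = w[:4], w[4:].reshape(4, 4)
    dz = OMEGA @ np.asarray(grad_H(*z), dtype=float)
    dJ = OMEGA @ np.asarray(hess_H(*z), dtype=float) @ J
    return np.concatenate([dz, dJ.ravel()])


w0 = np.concatenate([[0.0, 0.0, 1.0, 0.5], np.eye(4).ravel()])
t_eval = np.linspace(0.0, 6.0, 601)
sol = solve_ivp(augmented_rhs, (0.0, 6.0), w0, t_eval=t_eval,
                method='DOP853', rtol=1e-10, atol=1e-12)

J = sol.y[4:].T.reshape(-1, 4, 4)            # one 4x4 Jacobian per sample
det_caustic = np.linalg.det(J[:, 0:2, 2:4])  # det ∂(x, y)/∂(ξ₀, η₀)
crossings = np.where(np.diff(np.sign(det_caustic)))[0]
crossings = crossings[crossings > 0]         # t = 0 is trivially singular

print(sol.t[crossings])
```

The sign changes appear near t = π/2 and t = 3π/2. At t = π, both rays refocus simultaneously and sin²(t)·cos(t) has a double zero without a sign change. A pure sign-change test therefore misses this focal point of multiplicity 2, which matters when crossings are counted as a Maslov index.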

def find_periodic_orbits_2d(self, energy: float, x_range: Tuple[float, float], y_range: Tuple[float, float], xi_range: Tuple[float, float], eta_range: Tuple[float, float], n_attempts: int = 30) -> List[src.geometry_2d.PeriodicOrbit2D]:
340    def find_periodic_orbits_2d(self, energy: float,
341                               x_range: Tuple[float, float],
342                               y_range: Tuple[float, float],
343                               xi_range: Tuple[float, float],
344                               eta_range: Tuple[float, float],
345                               n_attempts: int = 30) -> List[PeriodicOrbit2D]:
346        """
347        Search for periodic orbits with Maslov index computation
348        """
349        orbits = []
350        # Sample configuration space
351        n_samples = int(np.sqrt(n_attempts))
352        x_samples = np.linspace(x_range[0], x_range[1], n_samples)
353        y_samples = np.linspace(y_range[0], y_range[1], n_samples)
354        for x0 in x_samples:
355            for y0 in y_samples:
356                # Test different momentum directions
357                angles = np.linspace(0, 2*np.pi, 8, endpoint=False)  # 0 and 2π are the same direction
358                for angle in angles:
359                    for r in np.linspace(0.5, 3, 3):
360                        xi0_guess = r * np.cos(angle)
361                        eta0_guess = r * np.sin(angle)
362                        try:
363                            # Energy check
364                            E_test = self.H_num(x0, y0, xi0_guess, eta0_guess)
365                            if abs(E_test - energy) > 0.5:
366                                continue
367                            # Compute geodesic
368                            geo = self.compute_geodesic(x0, y0, xi0_guess, eta0_guess, 15, 1500)
369                            # Search for return points
370                            distances = np.sqrt((geo.x - x0)**2 + (geo.y - y0)**2 +
371                                              (geo.xi - xi0_guess)**2 + (geo.eta - eta0_guess)**2)
372                            minima = []
373                            for i in range(10, len(distances)-10):
374                                if (distances[i] < distances[i-1] and
375                                    distances[i] < distances[i+1] and
376                                    distances[i] < 0.05):
377                                    minima.append(i)
378                            if minima:
379                                idx = minima[0]
380                                period = geo.t[idx]
381                                if period > 0.2 and distances[idx] < 0.05:
382                                    # Compute action
383                                    x_cyc = geo.x[:idx+1]
384                                    y_cyc = geo.y[:idx+1]
385                                    xi_cyc = geo.xi[:idx+1]
386                                    eta_cyc = geo.eta[:idx+1]
387                                    t_cyc = geo.t[:idx+1]
388                                    dx_dt = np.gradient(x_cyc, t_cyc)
389                                    dy_dt = np.gradient(y_cyc, t_cyc)
390                                    action = np.trapz(xi_cyc * dx_dt + eta_cyc * dy_dt, t_cyc)
391                                    # Compute Maslov index (number of caustic crossings)
392                                    maslov_index = len([i for i in geo.caustic_indices if i < idx])
393                                    # Compute stability
394                                    stab1 = self._compute_stability_2d(x0, y0, xi0_guess, eta0_guess, period)
395                                    orbits.append(PeriodicOrbit2D(
396                                        x0=x0, y0=y0,
397                                        xi0=xi0_guess, eta0=eta0_guess,
398                                        period=period,
399                                        action=action,
400                                        energy=float(E_test),  # H at the initial condition, within 0.5 of the target
401                                        stability_1=stab1,
402                                        stability_2=0.0,
403                                        x_cycle=x_cyc,
404                                        y_cycle=y_cyc,
405                                        xi_cycle=xi_cyc,
406                                        eta_cycle=eta_cyc,
407                                        t_cycle=t_cyc,
408                                        maslov_index=maslov_index
409                                    ))
410                        except Exception as e:
411                            continue
412        return self._remove_duplicate_orbits_2d(orbits)

Search for periodic orbits with Maslov index computation
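find_periodic_orbits_2d accepts a trial orbit when the phase-space distance to its starting point has a local minimum below 0.05. It takes the first such minimum as the period T and evaluates the action S = ∮(ξ dx + η dy) = ∫₀ᵀ (ξẋ + ηẏ) dt using `np.gradient` and the trapezoid rule.

The sketch below replays these steps on an analytic ellipse orbit of the isotropic oscillator H = (ξ² + η² + x² + y²)/2, for which T = 2π and S = 2πE. The analytic trajectory stands in for compute_geodesic. The trapezoid sum is written out by hand, so the sketch does not depend on whether the installed NumPy exposes `np.trapz` or `np.trapezoid`.

```python
import numpy as np

# Analytic orbit x = A cos t, y = B sin t of H = (ξ² + η² + x² + y²)/2 (stand-in for compute_geodesic)
A, B = 1.0, 0.5
t = np.linspace(0.0, 15.0, 1500)
x, y = A * np.cos(t), B * np.sin(t)
xi, eta = -A * np.sin(t), B * np.cos(t)
energy = 0.5 * (xi[0]**2 + eta[0]**2 + x[0]**2 + y[0]**2)

# Phase-space distance to the initial point, as in find_periodic_orbits_2d
dist = np.sqrt((x - x[0])**2 + (y - y[0])**2 + (xi - xi[0])**2 + (eta - eta[0])**2)

# First strict local minimum below the 0.05 threshold, skipping the first samples
minima = [i for i in range(10, len(dist) - 10)
          if dist[i] < dist[i - 1] and dist[i] < dist[i + 1] and dist[i] < 0.05]
idx = minima[0]
period = t[idx]

# Action S = ∫ (ξ dx/dt + η dy/dt) dt over one cycle, trapezoid rule written out explicitly
tc = t[:idx + 1]
integrand = (xi[:idx + 1] * np.gradient(x[:idx + 1], tc)
             + eta[:idx + 1] * np.gradient(y[:idx + 1], tc))
action = np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(tc))

print(period, action, 2 * np.pi * energy)
```

Both the period and the action are only as accurate as the sampling: the detected return time can be off by up to half a time step, because the return is located on the sample grid.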

def detect_caustic_structures(self, geodesics: List[src.geometry_2d.Geodesic2D], t_fixed: float) -> List[src.geometry_2d.CausticStructure]:
455    def detect_caustic_structures(self, geodesics: List[Geodesic2D], 
456                                 t_fixed: float) -> List[CausticStructure]:
457        """
458        Advanced caustic structure detection with classification
459        """
460        caustic_points = []
461        for geo in geodesics:
462            # Find closest time to t_fixed
463            idx = np.argmin(np.abs(geo.t - t_fixed))
464            # Check if near a caustic; t = 0 is skipped because det = 0 there trivially
465            if idx > 0 and abs(geo.det_caustic[idx]) < 0.1:
466                # Classify caustic type
467                caustic_type = self._classify_caustic(geo, idx)
468                # Compute singularity strength
469                strength = 1.0 / (abs(geo.det_caustic[idx]) + 0.01)
470                caustic_points.append({
471                    'x': geo.x[idx],
472                    'y': geo.y[idx],
473                    'energy': geo.energy,
474                    'type': caustic_type,
475                    'strength': strength
476                })
477        if len(caustic_points) < 3:
478            return []
479        # Cluster points into caustic structures
480        caustic_structures = self._cluster_caustic_points(caustic_points, t_fixed)
481        return caustic_structures

Advanced caustic structure detection with classification
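detect_caustic_structures keeps the geodesic samples with |det| < 0.1 at the time closest to t_fixed and hands them to `_cluster_caustic_points`. That helper grows one cluster around each unvisited seed, using all unvisited points within distance 0.5 of the seed, and then assigns the majority caustic type.

A compact, vectorized sketch of the same greedy rule follows. The function name `greedy_clusters` and the point dictionaries are hypothetical test data. Only coordinates and types are handled; the listing additionally averages a strength and attaches an approximate Maslov index.

```python
from collections import Counter

import numpy as np


def greedy_clusters(points, radius=0.5):
    """Group point dicts around seeds, mirroring the seed-radius rule of _cluster_caustic_points."""
    coords = np.array([[p['x'], p['y']] for p in points])
    unvisited = np.ones(len(points), dtype=bool)
    clusters = []
    for i in range(len(points)):
        if not unvisited[i]:
            continue
        # Distances are measured from the seed only (no transitive chaining)
        d = np.hypot(*(coords - coords[i]).T)
        members = np.flatnonzero(unvisited & (d < radius))
        unvisited[members] = False
        types = Counter(points[j]['type'] for j in members)
        clusters.append({'members': members.tolist(),
                         'type': types.most_common(1)[0][0]})
    return clusters


pts = [{'x': 0.0, 'y': 0.0, 'type': 'fold'},
       {'x': 0.3, 'y': 0.0, 'type': 'fold'},
       {'x': 0.6, 'y': 0.0, 'type': 'cusp'},  # 0.6 from the first seed: starts a new cluster
       {'x': 3.0, 'y': 3.0, 'type': 'cusp'}]
for c in greedy_clusters(pts):
    print(c)
```

Because distances are measured from the seed only, a chain of points spaced 0.3 apart can be split into several clusters. A single-linkage method such as `scipy.cluster.hierarchy.fcluster` would merge them instead; that would be a different technique from the one in the listing.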

def compute_phase_space_volume(self, E_max: float, x_range: tuple, y_range: tuple, xi_range: tuple, eta_range: tuple, n_samples: int = 200000) -> float:
554    def compute_phase_space_volume(self, E_max: float, x_range: tuple, y_range: tuple,
555                                 xi_range: tuple, eta_range: tuple, 
556                                 n_samples: int = 200000) -> float:
557        """Monte Carlo estimation of phase space volume for H ≤ E_max"""
558        # Generate random samples
559        x_samples = np.random.uniform(x_range[0], x_range[1], n_samples)
560        y_samples = np.random.uniform(y_range[0], y_range[1], n_samples)
561        xi_samples = np.random.uniform(xi_range[0], xi_range[1], n_samples)
562        eta_samples = np.random.uniform(eta_range[0], eta_range[1], n_samples)
563        # Evaluate Hamiltonian
564        H_vals = self.H_num(x_samples, y_samples, xi_samples, eta_samples)
565        # Count points where H ≤ E_max
566        volume_ratio = np.mean(H_vals <= E_max)
567        # Total phase space volume
568        total_volume = ((x_range[1]-x_range[0]) * (y_range[1]-y_range[0]) * 
569                       (xi_range[1]-xi_range[0]) * (eta_range[1]-eta_range[0]))
570        return volume_ratio * total_volume

Monte Carlo estimation of phase space volume for H ≤ E_max
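compute_phase_space_volume draws uniform samples in the box x_range × y_range × xi_range × eta_range. It then multiplies the box volume by the fraction of samples with H ≤ E_max. Two caveats follow from this design:

- The box must enclose the whole region {H ≤ E_max}; otherwise the estimate is truncated.
- The statistical error shrinks like n_samples^(−1/2).

The sketch below repeats the estimate for the isotropic oscillator H = (x² + y² + ξ² + η²)/2. Its sublevel set is a 4-ball of radius √(2E), with volume 2π²E². The sketch uses a seeded `numpy.random.default_rng` generator instead of the global `np.random` state; this is an addition for reproducibility, and `mc_phase_space_volume` is a hypothetical standalone name. Dividing the volume by (2πħ)² gives the Weyl estimate of the number of states below E, which is E²/2 here for ħ = 1.

```python
import numpy as np


def mc_phase_space_volume(H, E_max, ranges, n_samples=200_000, seed=0):
    """Monte Carlo volume of {H <= E_max} inside the box given by four (low, high) ranges."""
    rng = np.random.default_rng(seed)
    lows = np.array([r[0] for r in ranges], dtype=float)
    highs = np.array([r[1] for r in ranges], dtype=float)
    # One uniform sample per row; columns are x, y, xi, eta
    samples = rng.uniform(lows, highs, size=(n_samples, 4))
    inside = H(*samples.T) <= E_max
    box_volume = np.prod(highs - lows)
    return inside.mean() * box_volume


def H_osc(x, y, xi, eta):
    return 0.5 * (x**2 + y**2 + xi**2 + eta**2)


E = 1.0
box = [(-2.0, 2.0)] * 4  # encloses the ball of radius sqrt(2) ≈ 1.414
volume = mc_phase_space_volume(H_osc, E, box)
exact = 2 * np.pi**2 * E**2
print(volume, exact, volume / (2 * np.pi)**2)  # last value: Weyl count, about E²/2
```

With 200 000 samples in this box, the standard error of the volume is roughly 0.15, about 1% of the exact value.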

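In semiclassical (WKB) theory, each simple fold caustic crossed by a ray raises the Maslov index by one and contributes a phase of −π/2. A wave that has crossed caustics with total index μ therefore picks up a factor exp(−iμπ/2). The `_plot_maslov_index_phase_shifts` panel of SymbolVisualizer2D illustrates this on schematic caustic positions rather than on computed geodesics.

A vectorized version of that accumulation is sketched here, using the panel's illustrative positions and indices. Each caustic must contribute its shift exactly once, which the mask `x >= caust_x` guarantees. A proximity test such as |x − caust_x| < 0.05 would instead fire at every grid point inside the window.

```python
import numpy as np

x = np.linspace(-4.0, 4.0, 1000)
k = 2.0
psi_free = np.exp(1j * k * x**2 / 2)  # reference wave without Maslov shifts

caustic_positions = np.array([-2.0, 0.0, 2.0])  # illustrative, as in the plotting panel
maslov_indices = np.array([1, 2, 1])

# crossed[i, j] is 1 once x[i] lies at or beyond caustic j
crossed = (x[:, None] >= caustic_positions[None, :]).astype(int)
mu_total = crossed @ maslov_indices  # accumulated Maslov index at each x
psi_shifted = psi_free * np.exp(-1j * np.pi / 2 * mu_total)

print(mu_total[0], mu_total[-1])  # 0 before the first caustic, 4 after the last
```

Past the last caustic, the total index is 4, so the accumulated phase is −2π and the shifted wave coincides with the reference again.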
class SymbolVisualizer2D:
 575class SymbolVisualizer2D:
 576    """
 577    Complete visualization combining geometric and physical aspects
 578    """
 579    def __init__(self, geometry: SymbolGeometry2D):
 580        self.geo = geometry
 581
 582    def visualize_complete(self,
 583                          x_range: Tuple[float, float],
 584                          y_range: Tuple[float, float],
 585                          xi_range: Tuple[float, float],
 586                          eta_range: Tuple[float, float],
 587                          geodesics_params: List[Tuple],
 588                          E_range: Optional[Tuple[float, float]] = None,
 589                          hbar: float = 1.0,
 590                          resolution: int = 50) -> Tuple:
 591        """
 592        Create an adaptive multi-panel figure combining geometry and physics; panels without data are omitted
 593        Parameters
 594        ----------
 595        x_range, y_range : tuple
 596            Configuration space domain
 597        xi_range, eta_range : tuple
 598            Momentum space domain
 599        geodesics_params : list
 600            Geodesic parameters: (x0, y0, xi0, eta0, t_max[, color]); color defaults to 'blue'
 601        E_range : tuple, optional
 602            Energy interval for spectral analysis
 603        hbar : float
 604            Reduced Planck constant
 605        resolution : int
 606            Grid resolution
 607        Returns
 608        -------
 609        fig, geodesics, periodic_orbits, caustics
 610        """
 611        # Compute geodesics with caustic detection
 612        geodesics = self._compute_geodesics(geodesics_params)
 613        # Search for periodic orbits
 614        periodic_orbits = []
 615        if E_range:
 616            energies = np.linspace(E_range[0], E_range[1], 5)
 617            for E in energies:
 618                orbits = self.geo.find_periodic_orbits_2d(
 619                    E, x_range, y_range, xi_range, eta_range, n_attempts=20
 620                )
 621                periodic_orbits.extend(orbits)
 622        # Detect caustic structures
 623        caustics = []
 624        if geodesics:
 625            t_samples = np.linspace(0, geodesics[0].t[-1], 5)
 626            for t in t_samples:
 627                caustics.extend(self.geo.detect_caustic_structures(geodesics, t))
 628        # Create full figure
 629        fig = self._create_complete_figure(
 630            E_range, x_range, y_range, xi_range, eta_range,
 631            geodesics, periodic_orbits, caustics, hbar, resolution
 632        )
 633        return fig, geodesics, periodic_orbits, caustics
 634    
 635    def _compute_geodesics(self, params):
 636        """Compute geodesics with caustic detection"""
 637        geodesics = []
 638        for p in params:
 639            x0, y0, xi0, eta0, t_max = p[:5]
 640            geo = self.geo.compute_geodesic(x0, y0, xi0, eta0, t_max)
 641            geo.color = p[5] if len(p) > 5 else 'blue'
 642            geodesics.append(geo)
 643        return geodesics
 644
 645    
 646    def _create_complete_figure(self, E_range, x_range, y_range, xi_range, eta_range,
 647                               geodesics, periodic_orbits, caustics, hbar, resolution):
 648        """Creates an adaptive multi-panel figure: only relevant panels are displayed."""
 649        
 650        # --- List of panels with explicit call signatures ---
 651        panels_to_plot = []
 652    
 653        # Always safe to plot if data exists
 654        if geodesics:
 655            panels_to_plot.append(lambda ax_spec: self._plot_energy_surface_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 656            panels_to_plot.append(lambda ax_spec: self._plot_configuration_space(fig, ax_spec, geodesics, caustics))
 657            panels_to_plot.append(lambda ax_spec: self._plot_phase_projection_x(fig, ax_spec, geodesics))
 658            panels_to_plot.append(lambda ax_spec: self._plot_phase_projection_y(fig, ax_spec, geodesics))
 659            panels_to_plot.append(lambda ax_spec: self._plot_momentum_space(fig, ax_spec, geodesics))
 660            panels_to_plot.append(lambda ax_spec: self._plot_vector_field_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 661            panels_to_plot.append(lambda ax_spec: self._plot_group_velocity_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 662            panels_to_plot.append(lambda ax_spec: self._plot_caustic_curves_2d(fig, ax_spec, geodesics, caustics))
 663            panels_to_plot.append(lambda ax_spec: self._plot_jacobian_evolution(fig, ax_spec, geodesics))
 664            panels_to_plot.append(lambda ax_spec: self._plot_energy_conservation_2d(fig, ax_spec, geodesics))
 665            panels_to_plot.append(lambda ax_spec: self._plot_poincare_x(fig, ax_spec, geodesics))
 666            panels_to_plot.append(lambda ax_spec: self._plot_poincare_y(fig, ax_spec, geodesics))
 667            panels_to_plot.append(lambda ax_spec: self._plot_caustic_network(fig, ax_spec, x_range, y_range, geodesics))
 668    
 669        if geodesics and caustics:
 670            pass  # already handled above
 671    
 672        if periodic_orbits:
 673            panels_to_plot.append(lambda ax_spec: self._plot_periodic_orbits_3d(fig, ax_spec, periodic_orbits))
 674            panels_to_plot.append(lambda ax_spec: self._plot_action_energy_2d(fig, ax_spec, periodic_orbits))
 675            panels_to_plot.append(lambda ax_spec: self._plot_torus_quantization(fig, ax_spec, periodic_orbits, hbar))
 676            if len(periodic_orbits) > 2:
 677                panels_to_plot.append(lambda ax_spec: self._plot_level_spacing_2d(fig, ax_spec, periodic_orbits))
 678    
 679        if periodic_orbits and E_range:
 680            panels_to_plot.append(lambda ax_spec: self._plot_spectral_density_with_caustics(fig, ax_spec, periodic_orbits, E_range))
 681    
 682        # Always plot the Maslov phase-shift panel (schematic; not computed from the geodesics)
 683        panels_to_plot.append(lambda ax_spec: self._plot_maslov_index_phase_shifts(fig, ax_spec, geodesics, caustics))
 684    
 685        if E_range:
 686            panels_to_plot.append(lambda ax_spec: self._plot_phase_space_volume(fig, ax_spec, E_range, x_range, y_range, xi_range, eta_range))
 687    
 688        # --- Handle empty case ---
 689        if not panels_to_plot:
 690            fig, ax = plt.subplots(figsize=(10, 6))
 691            ax.text(0.5, 0.5, "No panels to display for this Hamiltonian.",
 692                    ha='center', va='center', fontsize=16, transform=ax.transAxes)
 693            ax.set_axis_off()
 694            return fig
 695    
 696        # --- Dynamic layout ---
 697        n = len(panels_to_plot)
 698        if n <= 5:
 699            cols, rows = n, 1
 700        elif n <= 10:
 701            cols, rows = 5, 2
 702        elif n <= 15:
 703            cols, rows = 5, 3
 704        else:
 705            cols, rows = 5, (n + 4) // 5
 706    
 707        figsize = (4.8 * cols, 4.0 * rows)
 708        fig = plt.figure(figsize=figsize)
 709        gs = GridSpec(rows, cols, figure=fig, hspace=0.5, wspace=0.3)
 710        plt.suptitle(f'Geometric and Semiclassical Atlas: H = {self.geo.H_sym} (ℏ={hbar})',
 711                     fontsize=18, fontweight='bold', y=0.98)
 712    
 713        # --- Plot all panels ---
 714        for idx, plot_cmd in enumerate(panels_to_plot):
 715            if idx >= rows * cols:
 716                break
 717            row = idx // cols
 718            col = idx % cols
 719            subplot_spec = gs[row, col]
 720            try:
 721                plot_cmd(subplot_spec)
 722            except Exception as e:
 723                ax = fig.add_subplot(subplot_spec)
 724                ax.text(0.5, 0.5, f"[Error]\n{type(e).__name__}", ha='center', va='center', color='red')
 725                ax.set_axis_off()
 726    
 727        plt.tight_layout(rect=[0, 0.02, 1, 0.95])
 728        return fig
 729
 730    # ======== DETAILED VISUALIZATION METHODS ========
 731    def _plot_energy_surface_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
 732        """Energy surface H(x,y) at fixed (ξ,η)"""
 733        ax = fig.add_subplot(subplot_spec, projection='3d')
 734        x = np.linspace(x_range[0], x_range[1], res)
 735        y = np.linspace(y_range[0], y_range[1], res)
 736        X, Y = np.meshgrid(x, y)
 737        # Evaluate at reference momentum
 738        xi_ref, eta_ref = 1.0, 1.0
 739        Z = np.zeros_like(X)
 740        for i in range(X.shape[0]):
 741            for j in range(X.shape[1]):
 742                try:
 743                    Z[i,j] = self.geo.H_num(X[i,j], Y[i,j], xi_ref, eta_ref)
 744                except Exception:
 745                    Z[i,j] = np.nan
 746        # Surface with transparency to see geodesics
 747        ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.6, edgecolor='none')
 748        # Geodesics on the surface
 749        for geo in geodesics[:5]:
 750            H_geo = np.array([self.geo.H_num(geo.x[i], geo.y[i], xi_ref, eta_ref)
 751                             for i in range(len(geo.t))])
 752            color = getattr(geo, 'color', 'red')
 753            ax.plot(geo.x, geo.y, H_geo, color=color, linewidth=2.5)
 754        ax.set_xlabel('x')
 755        ax.set_ylabel('y')
 756        ax.set_zlabel('H')
 757        ax.set_title('Energy Surface\nH(x,y,ξ₀,η₀)', fontweight='bold', fontsize=10)
 758        ax.view_init(elev=25, azim=-45)
 759    
 760    def _plot_configuration_space(self, fig, subplot_spec, geodesics, caustics):
 761        """Configuration space (x,y) with trajectories and caustics"""
 762        ax = fig.add_subplot(subplot_spec)
 763        
 764        # Trajectories (semi-transparent), with a black-edged marker at each starting point
 765        for geo in geodesics:
 766            color = getattr(geo, 'color', 'blue')
 767            ax.plot(geo.x, geo.y, color=color, linewidth=1.5, alpha=0.7, zorder=5)
 768            ax.scatter([geo.x[0]], [geo.y[0]], color=color, s=80, 
 769                      marker='o', edgecolors='black', linewidths=1.5, zorder=10)
 770        
 771        # Caustic points detected along each trajectory (red stars)
 772        for geo in geodesics:
 773            caust_x, caust_y = geo.caustic_points
 774            if len(caust_x) > 0:
 775                ax.scatter(caust_x, caust_y, c='red', s=80, marker='*',
 776                          edgecolors='darkred', linewidths=1.0, zorder=15,
 777                          label='Caustic points')
 778        
 779        # Caustic structures: small, semi-transparent dots colored by type
 780        color_map = {'fold': 'red', 'cusp': 'magenta', 'swallowtail': 'orange'}
 781        for caust in caustics:
 782            color = color_map.get(caust.type, 'red')
 783            marker = 'o'
 784            # Fixed size and transparency so the dots do not obscure trajectories
 785            size = 30
 786            alpha_val = 0.5
 787            
 788            # zorder 12 draws the dots above trajectories (5) and below caustic stars (15)
 789            ax.scatter(caust.x, caust.y, c=color, s=size, marker=marker,
 790                      edgecolors='none', linewidths=0,
 791                      alpha=alpha_val, zorder=12,
 792                      label=f'Caustic {caust.type} (μ={caust.maslov_index})')
 793        
 794        ax.set_xlabel('x')
 795        ax.set_ylabel('y')
 796        ax.set_title('Configuration Space\n★ = caustics', fontweight='bold', fontsize=10)
 797        ax.grid(True, alpha=0.3)
 798        ax.set_aspect('equal')
 799        
 800        # Legend without duplicates
 801        handles, labels = ax.get_legend_handles_labels()
 802        by_label = dict(zip(labels, handles))
 803        if by_label:
 804            ax.legend(by_label.values(), by_label.keys(), fontsize=8, loc='upper right')
 805    
 806    def _plot_jacobian_evolution(self, fig, subplot_spec, geodesics):
 807        """Evolution of Jacobian determinant with caustic detection"""
 808        ax = fig.add_subplot(subplot_spec)
 809        for geo in geodesics:
 810            color = getattr(geo, 'color', 'blue')
 811            ax.plot(geo.t, geo.det_caustic, color=color, linewidth=2.5, alpha=0.9,
 812                   label=f'E={geo.energy:.2f}')
 813            # Mark caustic points
 814            for idx in geo.caustic_indices:
 815                ax.scatter(geo.t[idx], geo.det_caustic[idx], s=100, marker='*',
 816                          color='red', edgecolor='darkred', zorder=10)
 817        ax.axhline(0, color='red', linestyle='--', linewidth=2, alpha=0.7)
 818        ax.set_xlabel('Time t')
 819        ax.set_ylabel('det(∂(x,y)/∂(ξ₀,η₀))')
 820        ax.set_title('Jacobian Determinant\nZeros = caustics', fontweight='bold', fontsize=10)
 821        ax.grid(True, alpha=0.3)
 822        ax.legend(fontsize=8)
 823    
 824    def _plot_maslov_index_phase_shifts(self, fig, subplot_spec, geodesics, caustics):
 825        """Visualization of phase shifts due to Maslov index"""
 826        ax = fig.add_subplot(subplot_spec)
 827        # Simulate wavefunction crossing caustics
 828        x_demo = np.linspace(-4, 4, 1000)
 829        k = 2.0  # Wavenumber
 830        # Reference wavefunction without any Maslov phase shift
 831        psi_free = np.exp(1j * k * x_demo**2 / 2)
 832        # Illustrative caustic positions and their Maslov indices
 833        caustic_positions = [-2.0, 0.0, 2.0]
 834        maslov_indices = [1, 2, 1]
 835        # Accumulated phase along x: every caustic with x >= caust_x has been
 836        # crossed exactly once and contributes a shift of -μπ/2
 837        current_phase = np.zeros_like(x_demo)
 838        for caust_x, mu in zip(caustic_positions, maslov_indices):
 839            crossed = x_demo >= caust_x
 840            current_phase[crossed] -= mu * np.pi / 2
 841        # Apply the accumulated Maslov phase to the reference wavefunction
 842        psi_with_shifts = psi_free * np.exp(1j * current_phase)
 843        # Plot real parts
 844        ax.plot(x_demo, np.real(psi_free), 'b-', alpha=0.8, linewidth=2, 
 845                label='Re[ψ] without Maslov shifts')
 846        ax.plot(x_demo, np.real(psi_with_shifts), 'r-', alpha=0.8, linewidth=2, 
 847                label='Re[ψ] with Maslov shifts')
 848        # Mark caustic positions
 849        for i, caust_x in enumerate(caustic_positions):
 850            ax.axvline(caust_x, color='k', linestyle='--', alpha=0.7,
 851                      label=f'Caustic μ={maslov_indices[i]}')
 852        ax.set_xlabel('Position x')
 853        ax.set_ylabel('Re[ψ(x)]')
 854        ax.set_title('Maslov Index\nPhase shifts at caustics', fontweight='bold', fontsize=10)
 855        ax.set_ylim(-1.5, 1.5)
 856        ax.grid(True, alpha=0.3)
 857        ax.legend(fontsize=8, loc='upper right')
 858    
 859    def _plot_spectral_density_with_caustics(self, fig, subplot_spec, periodic_orbits, E_range):
 860        """Spectral density with caustic corrections"""
 861        ax = fig.add_subplot(subplot_spec)
 862        if not periodic_orbits:
 863            ax.text(0.5, 0.5, 'No periodic orbits', 
 864                   ha='center', va='center', transform=ax.transAxes)
 865            return
 866        # Sort orbits by energy
 867        orbits_sorted = sorted(periodic_orbits, key=lambda x: x.energy)
 868        energies = np.array([orb.energy for orb in orbits_sorted])
 869        periods = np.array([orb.period for orb in orbits_sorted])
 870        # Smooth term: finite-difference estimate of dT/dE along the energy-sorted
 871        # orbit family, used here as a proxy for the smooth state density ρ(E)
 872        if len(energies) > 1:
 873            rho_E = np.zeros_like(energies)
 874            with np.errstate(divide='ignore', invalid='ignore'):
 875                rho_E[1:-1] = (periods[2:] - periods[:-2]) / (energies[2:] - energies[:-2])
 876                rho_E[0] = (periods[1] - periods[0]) / (energies[1] - energies[0])
 877                rho_E[-1] = (periods[-1] - periods[-2]) / (energies[-1] - energies[-2])
 878            # Zero out non-finite values (repeated energies) and negative values
 879            rho_E = np.maximum(np.nan_to_num(rho_E, nan=0.0, posinf=0.0, neginf=0.0), 0)
 880            # Caustic correction (oscillatory terms)
 881            rho_osc = np.zeros_like(rho_E)
 882            for orb in orbits_sorted:
 883                # Amplitude depending on Maslov index
 884                amp = 0.3 * np.exp(-orb.maslov_index/2) * orb.period
 885                phase = orb.action / self.geo.hbar - np.pi * orb.maslov_index / 2
 886                idx = np.argmin(np.abs(energies - orb.energy))
 887                if 0 <= idx < len(rho_osc):
 888                    rho_osc[idx] += amp * np.cos(phase)
 889            # Smooth curve
 890            E_fine = np.linspace(E_range[0], E_range[1], 500)
 891            from scipy.interpolate import interp1d
 892            try:
 893                interp_rho = interp1d(energies, rho_E, kind='cubic', fill_value="extrapolate")
 894                interp_osc = interp1d(energies, rho_osc, kind='cubic', fill_value="extrapolate")
 895                rho_smooth = np.maximum(0, interp_rho(E_fine))
 896                rho_osc_smooth = interp_osc(E_fine)
 897                # Plot components
 898                ax.plot(E_fine, rho_smooth, 'k-', linewidth=2.5, 
 899                       label='Smooth (Weyl)')
 900                ax.plot(E_fine, rho_smooth + rho_osc_smooth, 'b-', linewidth=2,
 901                       label='Total with caustics')
 902                ax.fill_between(E_fine, rho_smooth, rho_smooth + rho_osc_smooth, 
 903                               where=rho_osc_smooth>0, color='#ff9999', alpha=0.4,
 904                               label='Caustic corrections')
 905            except:
 906                ax.plot(energies, rho_E, 'b-o', linewidth=2, label='State density ρ(E)')
 907        ax.set_xlabel('Energy E')
 908        ax.set_ylabel('ρ(E)')
 909        ax.set_title('Spectral Density\nwith caustic corrections', fontweight='bold', fontsize=10)
 910        ax.grid(True, alpha=0.3)
 911        ax.legend(fontsize=8)
 912    
    def _plot_phase_space_volume(self, fig, subplot_spec, E_range, x_range, y_range, xi_range, eta_range):
        """Phase space volume via Monte Carlo"""
        from scipy.ndimage import gaussian_filter1d
        ax = fig.add_subplot(subplot_spec)
        # Compute volume for different energies
        E_vals = np.linspace(E_range[0], E_range[1], 8)
        volumes = []
        print("Computing phase space volume (Monte Carlo)...")
        for E in E_vals:
            vol = self.geo.compute_phase_space_volume(E, x_range, y_range, xi_range, eta_range, n_samples=50000)
            volumes.append(vol)
            print(f"  E={E:.2f}, Volume={vol:.4f}")
        # Weyl law: N(E) ~ Vol/(2πℏ)^d with d = 2 degrees of freedom
        d = 2
        weyl_constant = (2 * np.pi * self.geo.hbar) ** d
        N_weyl = np.array(volumes) / weyl_constant
        # Color is passed only as a keyword ('b-o' plus color= would define it twice)
        ax.plot(E_vals, N_weyl, '-o', color='#1f77b4', linewidth=2.5, markersize=8,
                label='Weyl law: N(E) ~ Vol/(2πℏ)²')
        # Conceptual caustic correction: an illustrative sinusoidal modulation of
        # the Weyl count, not computed from the periodic orbits
        if len(E_vals) > 3:
            oscillation_freq = 5 / (E_range[1] - E_range[0])
            correction = 0.15 * N_weyl * np.sin(2 * np.pi * oscillation_freq * (E_vals - E_vals[0]) + 0.7)
            N_corrected = N_weyl + correction
            N_corrected_smooth = gaussian_filter1d(N_corrected, sigma=1.0)
            ax.plot(E_vals, N_corrected_smooth, 'r--', linewidth=2,
                   label="With caustic corrections", alpha=0.9)
        ax.set_xlabel('Energy E')
        ax.set_ylabel('N(E) (Number of states)')
        ax.set_title('Phase Space Volume\n(Monte Carlo)', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)

    def _plot_caustic_network(self, fig, subplot_spec, x_range, y_range, geodesics):
        """Caustic network with multiple initial conditions"""
        from scipy.optimize import fsolve
        ax = fig.add_subplot(subplot_spec)
        if not geodesics:
            ax.text(0.5, 0.5, 'No geodesics',
                   ha='center', va='center', transform=ax.transAxes)
            return
        # Use first geodesic as reference
        ref = geodesics[0]
        E_ref = ref.energy
        t_max = ref.t[-1]
        y0_ref, xi0_guess, eta0_ref = ref.y[0], ref.xi[0], ref.eta[0]
        # Generate trajectory family
        n_family = 15
        x0_vals = np.linspace(x_range[0], x_range[1], n_family)
        caustic_points = []
        for x0 in x0_vals:
            # Keep y0 and eta0 from the reference geodesic and solve the single
            # equation H(x0, y0, xi, eta0) = E_ref for xi (fsolve needs as many
            # equations as unknowns, so one equation cannot fix three unknowns)
            def energy_eq(xi):
                return self.geo.H_num(x0, y0_ref, xi, eta0_ref) - E_ref
            try:
                xi_sol, _, ier, _ = fsolve(energy_eq, xi0_guess, full_output=True)
                xi0_new = float(xi_sol[0])
                if ier != 1 or not np.isfinite(xi0_new):
                    continue  # E_ref is not reachable at this x0
                # Compute trajectory
                traj = self.geo.compute_geodesic(x0, y0_ref, xi0_new, eta0_ref, t_max, n_points=300)
            except Exception:
                continue
            # Plot trajectory
            ax.plot(traj.x, traj.y, color='blue', alpha=0.3, linewidth=1)
            # Collect caustic points
            caust_x, caust_y = traj.caustic_points
            caustic_points.extend(zip(caust_x, caust_y))
        # Plot caustic points
        if caustic_points:
            caustic_points = np.array(caustic_points)
            ax.scatter(caustic_points[:, 0], caustic_points[:, 1],
                      s=30, c='red', alpha=0.8, edgecolor='none',
                      label='Caustic points')
            ax.legend(fontsize=8)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title('Caustic Network\n(Multiple initial conditions)', fontweight='bold', fontsize=10)
        ax.set_xlim(x_range)
        ax.set_ylim(y_range)
        ax.grid(True, alpha=0.3)

    # ======== STANDARD VISUALIZATION METHODS (similar to v1) ========
    # The following methods are similar to v1 but enhanced
    # to integrate caustics and the new data structures
    def _plot_phase_projection_x(self, fig, subplot_spec, geodesics):
        """Phase space projection (x,ξ)"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            color = getattr(geo, 'color', 'blue')
            ax.plot(geo.x, geo.xi, color=color, linewidth=2, alpha=0.8)
            ax.scatter([geo.x[0]], [geo.xi[0]], color=color, s=80,
                      marker='o', edgecolors='black', linewidths=1.5)
        ax.set_xlabel('x')
        ax.set_ylabel('ξ')
        ax.set_title('Phase Space (x,ξ)', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)

    def _plot_phase_projection_y(self, fig, subplot_spec, geodesics):
        """Phase space projection (y,η)"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            color = getattr(geo, 'color', 'blue')
            ax.plot(geo.y, geo.eta, color=color, linewidth=2, alpha=0.8)
            ax.scatter([geo.y[0]], [geo.eta[0]], color=color, s=80,
                      marker='o', edgecolors='black', linewidths=1.5)
        ax.set_xlabel('y')
        ax.set_ylabel('η')
        ax.set_title('Phase Space (y,η)', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)

    def _plot_momentum_space(self, fig, subplot_spec, geodesics):
        """Momentum space (ξ,η)"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            color = getattr(geo, 'color', 'blue')
            ax.plot(geo.xi, geo.eta, color=color, linewidth=2, alpha=0.8)
            ax.scatter([geo.xi[0]], [geo.eta[0]], color=color, s=80,
                      marker='o', edgecolors='black', linewidths=1.5)
        ax.set_xlabel('ξ')
        ax.set_ylabel('η')
        ax.set_title('Momentum Space\n(ξ,η)', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')

    def _plot_vector_field_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
        """Vector field in configuration space"""
        ax = fig.add_subplot(subplot_spec)
        x = np.linspace(x_range[0], x_range[1], res//2)
        y = np.linspace(y_range[0], y_range[1], res//2)
        X, Y = np.meshgrid(x, y)
        # Evaluate the velocity (∂H/∂ξ, ∂H/∂η) at a fixed reference momentum
        xi_ref, eta_ref = 1.0, 1.0
        VX = np.zeros_like(X)
        VY = np.zeros_like(Y)
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                try:
                    VX[i,j] = self.geo.dH_dxi_num(X[i,j], Y[i,j], xi_ref, eta_ref)
                    VY[i,j] = self.geo.dH_deta_num(X[i,j], Y[i,j], xi_ref, eta_ref)
                except Exception:
                    VX[i,j] = VY[i,j] = np.nan
        # Color by the true magnitude; divide by a zero-safe copy to get unit arrows
        magnitude = np.sqrt(VX**2 + VY**2)
        safe_mag = np.where(magnitude > 0, magnitude, 1.0)
        ax.quiver(X, Y, VX/safe_mag, VY/safe_mag, magnitude,
                 cmap='plasma', alpha=0.7, scale=30)
        # Overlay geodesics
        for geo in geodesics[:5]:
            color = getattr(geo, 'color', 'white')
            ax.plot(geo.x, geo.y, color=color, linewidth=2.5, alpha=0.9)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title('Vector Field\nFlow in configuration space', fontweight='bold', fontsize=10)
        ax.set_aspect('equal')

    def _plot_group_velocity_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
        """Group velocity magnitude |∇_p H|"""
        ax = fig.add_subplot(subplot_spec)
        x = np.linspace(x_range[0], x_range[1], res)
        y = np.linspace(y_range[0], y_range[1], res)
        X, Y = np.meshgrid(x, y)
        # Group velocity at reference momentum
        xi_ref, eta_ref = 1.0, 1.0
        V_mag = np.zeros_like(X)
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                try:
                    vx = self.geo.dH_dxi_num(X[i,j], Y[i,j], xi_ref, eta_ref)
                    vy = self.geo.dH_deta_num(X[i,j], Y[i,j], xi_ref, eta_ref)
                    V_mag[i,j] = np.sqrt(vx**2 + vy**2)
                except Exception:
                    V_mag[i,j] = np.nan
        # Heatmap
        im = ax.contourf(X, Y, V_mag, levels=20, cmap='hot')
        fig.colorbar(im, ax=ax, label='|v_g|')
        # Geodesics
        for geo in geodesics[:5]:
            ax.plot(geo.x, geo.y, color='cyan', linewidth=2, alpha=0.8)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title('Group Velocity\n|∇_p H|', fontweight='bold', fontsize=10)
        ax.set_aspect('equal')

    def _plot_caustic_curves_2d(self, fig, subplot_spec, geodesics, caustics):
        """Caustic curves in (x,y) space"""
        ax = fig.add_subplot(subplot_spec)
        color_map = {'fold': 'red', 'cusp': 'magenta', 'swallowtail': 'orange'}
        # All geodesics
        for geo in geodesics:
            color = getattr(geo, 'color', 'lightblue')
            ax.plot(geo.x, geo.y, color=color, linewidth=1.5, alpha=0.5)
            # Caustic points on each geodesic
            caust_x, caust_y = geo.caustic_points
            if len(caust_x) > 0:
                ax.scatter(caust_x, caust_y, c='red', s=80, marker='*',
                          edgecolors='darkred', linewidths=1.5, zorder=10)
        # Complete caustic structures
        for caust in caustics:
            color = color_map.get(caust.type, 'red')
            # With enough points, draw the caustic as a connected curve
            if len(caust.x) > 3:
                ax.plot(caust.x, caust.y, color=color, linewidth=3,
                       label=f'Caustic {caust.type} (μ={caust.maslov_index})')
            else:
                ax.scatter(caust.x, caust.y, c=color, s=100, marker='X',
                          edgecolors='black', linewidths=1.5,
                          label=f'Caustic {caust.type}')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title('Caustic Curves\n★ = points on geodesics', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')
        # Legend without duplicates
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:
            ax.legend(by_label.values(), by_label.keys(), fontsize=8)

    def _plot_energy_conservation_2d(self, fig, subplot_spec, geodesics):
        """Energy conservation verification"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            color = getattr(geo, 'color', 'blue')
            H_var = (geo.H - geo.H[0]) / (np.abs(geo.H[0]) + 1e-10)
            ax.semilogy(geo.t, np.abs(H_var) + 1e-16,
                       color=color, linewidth=2, label=f'E={geo.H[0]:.2f}')
        ax.set_xlabel('Time t')
        ax.set_ylabel('|ΔH/H₀|')
        ax.set_title('Energy Conservation\nNumerical quality', fontweight='bold', fontsize=10)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, which='both')

    def _plot_poincare_x(self, fig, subplot_spec, geodesics):
        """Poincaré section (x,ξ) at y=0"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            # Find y=0 crossings
            crossings_x = []
            crossings_xi = []
            for i in range(len(geo.y)-1):
                if geo.y[i] * geo.y[i+1] < 0:  # Sign change
                    alpha = -geo.y[i] / (geo.y[i+1] - geo.y[i])
                    x_cross = geo.x[i] + alpha * (geo.x[i+1] - geo.x[i])
                    xi_cross = geo.xi[i] + alpha * (geo.xi[i+1] - geo.xi[i])
                    crossings_x.append(x_cross)
                    crossings_xi.append(xi_cross)
            if crossings_x:
                color = getattr(geo, 'color', 'blue')
                ax.scatter(crossings_x, crossings_xi, c=color, s=50, alpha=0.7)
        ax.set_xlabel('x')
        ax.set_ylabel('ξ')
        ax.set_title('Poincaré Section\n(x,ξ) at y=0', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)

    def _plot_poincare_y(self, fig, subplot_spec, geodesics):
        """Poincaré section (y,η) at x=0"""
        ax = fig.add_subplot(subplot_spec)
        for geo in geodesics:
            # Find x=0 crossings
            crossings_y = []
            crossings_eta = []
            for i in range(len(geo.x)-1):
                if geo.x[i] * geo.x[i+1] < 0:
                    alpha = -geo.x[i] / (geo.x[i+1] - geo.x[i])
                    y_cross = geo.y[i] + alpha * (geo.y[i+1] - geo.y[i])
                    eta_cross = geo.eta[i] + alpha * (geo.eta[i+1] - geo.eta[i])
                    crossings_y.append(y_cross)
                    crossings_eta.append(eta_cross)
            if crossings_y:
                color = getattr(geo, 'color', 'blue')
                ax.scatter(crossings_y, crossings_eta, c=color, s=50, alpha=0.7)
        ax.set_xlabel('y')
        ax.set_ylabel('η')
        ax.set_title('Poincaré Section\n(y,η) at x=0', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)

    def _plot_periodic_orbits_3d(self, fig, subplot_spec, periodic_orbits):
        """Periodic orbits in 3D (x,y,t)"""
        ax = fig.add_subplot(subplot_spec, projection='3d')
        colors = plt.cm.rainbow(np.linspace(0, 1, min(10, len(periodic_orbits))))
        for idx, orb in enumerate(periodic_orbits[:10]):  # Limit for clarity
            ax.plot(orb.x_cycle, orb.y_cycle, orb.t_cycle,
                   color=colors[idx], linewidth=2.5, alpha=0.8)
            ax.scatter([orb.x0], [orb.y0], [0], color=colors[idx],
                      s=100, marker='o', edgecolors='black', linewidths=2)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('t')
        ax.set_title('Periodic Orbits\nSpace-time view', fontweight='bold', fontsize=10)

    def _plot_action_energy_2d(self, fig, subplot_spec, periodic_orbits):
        """Action vs Energy"""
        ax = fig.add_subplot(subplot_spec)
        if not periodic_orbits:
            ax.text(0.5, 0.5, 'No periodic orbits',
                   ha='center', va='center', transform=ax.transAxes)
            return
        E_orb = [orb.energy for orb in periodic_orbits]
        S_orb = [orb.action for orb in periodic_orbits]
        T_orb = [orb.period for orb in periodic_orbits]
        scatter = ax.scatter(E_orb, S_orb, c=T_orb, s=150,
                           cmap='plasma', edgecolors='black', linewidths=1.5)
        fig.colorbar(scatter, ax=ax, label='Period T')
        ax.set_xlabel('Energy E')
        ax.set_ylabel('Action S')
        ax.set_title('Action-Energy\nS(E)', fontweight='bold', fontsize=10)
        ax.grid(True, alpha=0.3)

    def _plot_torus_quantization(self, fig, subplot_spec, periodic_orbits, hbar):
        """Torus quantization (KAM theory)"""
        ax = fig.add_subplot(subplot_spec)
        E_orb = [orb.energy for orb in periodic_orbits]
        S_orb = [orb.action for orb in periodic_orbits]
        ax.scatter(E_orb, S_orb, s=150, c='blue',
                   edgecolors='black', linewidths=1.5, label='Orbits')
        # EBK quantization for 2D: S_i = 2πℏ(n_i + α_i)
        # Simplified here to a single action with α = 1/2
        # Draw levels up to the largest orbit action (10 when there are no orbits)
        S_max = max(S_orb) if S_orb else 10
        E_min = min(E_orb) if E_orb else 0
        for n in range(20):
            S_quant = 2 * np.pi * hbar * (n + 0.5)
            if S_quant < S_max:
                ax.axhline(S_quant, color='red', linestyle='--', alpha=0.3)
                ax.text(E_min, S_quant, f'n={n}', fontsize=7, color='red')
        ax.set_xlabel('Energy E')
        ax.set_ylabel('Action S')
        ax.set_title('Torus Quantization\nKAM theory', fontweight='bold', fontsize=10)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    def _plot_level_spacing_2d(self, fig, subplot_spec, periodic_orbits):
        """Level spacing distribution"""
        ax = fig.add_subplot(subplot_spec)
        # Extract unique energies
        energies = sorted(set(orb.energy for orb in periodic_orbits))
        if len(energies) > 2:
            spacings = np.diff(energies)
            # Normalize
            s_mean = np.mean(spacings)
            s_norm = spacings / s_mean
            # Histogram
            ax.hist(s_norm, bins=15, density=True, alpha=0.7,
                   color='blue', edgecolor='black', label='Data')
            # Theoretical curves
            s = np.linspace(0, np.max(s_norm), 100)
            # Poisson (integrable systems)
            poisson = np.exp(-s)
            ax.plot(s, poisson, 'g--', linewidth=2, label='Poisson (Integrable)')
            # Wigner (chaotic systems)
            wigner = (np.pi * s / 2) * np.exp(-np.pi * s**2 / 4)
            ax.plot(s, wigner, 'r-', linewidth=2, label='Wigner (Chaotic)')
            ax.set_xlabel('Normalized spacing s')
            ax.set_ylabel('P(s)')
            ax.set_title('Level Spacing\nIntegrable vs Chaotic', fontweight='bold', fontsize=10)
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
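
The phase-space volume panel counts states with the Weyl estimate N(E) ≈ Vol{H ≤ E}/(2πℏ)² for two degrees of freedom. The sketch below reproduces that estimate standalone with plain uniform Monte Carlo sampling in a bounding box. It does not call `compute_phase_space_volume`, whose internals are not shown here; the helpers `phase_space_volume_mc` and `H_osc` are hypothetical. The isotropic oscillator H = (ξ² + η² + x² + y²)/2 is an assumed test case, chosen because its energy-shell volume (2π²E²) and its exact level count are both known in closed form.

```python
import numpy as np

# Illustrative sketch: phase_space_volume_mc and H_osc are hypothetical helpers.

def phase_space_volume_mc(H, E, bounds, n_samples=200_000, seed=0):
    """Estimate vol{(x, y, xi, eta) : H <= E} by uniform sampling in a box."""
    rng = np.random.default_rng(seed)
    lows = np.array([b[0] for b in bounds], dtype=float)
    highs = np.array([b[1] for b in bounds], dtype=float)
    pts = rng.uniform(lows, highs, size=(n_samples, len(bounds)))
    inside = H(*pts.T) <= E                  # one boolean per sample
    box_volume = np.prod(highs - lows)
    return box_volume * inside.mean()        # box volume times hit fraction

def H_osc(x, y, xi, eta):
    # Isotropic 2D harmonic oscillator with omega = 1
    return 0.5 * (xi**2 + eta**2 + x**2 + y**2)

hbar, E = 1.0, 10.0
L = np.sqrt(2 * E)                           # the region H <= E fits in [-L, L]^4
vol = phase_space_volume_mc(H_osc, E, [(-L, L)] * 4)
N_weyl = vol / (2 * np.pi * hbar) ** 2       # Weyl law with d = 2
M = int(np.floor(E / hbar))
N_exact = M * (M + 1) // 2                   # levels hbar*(n + 1), degeneracy n + 1

print(f"MC volume {vol:.1f}  (exact {2 * np.pi**2 * E**2:.1f})")
print(f"Weyl N(E) {N_weyl:.1f}  vs exact count {N_exact}")
```

With ℏ = 1 and E = 10 the Weyl term gives E²/(2ℏ²) = 50 states while the exact count is 55; the difference, M/2 = 5, is the lower-order correction that the leading Weyl term omits.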

Complete visualization combining geometric and physical aspects

SymbolVisualizer2D(geometry: SymbolGeometry2D)
    def __init__(self, geometry: SymbolGeometry2D):
        self.geo = geometry
geo
def visualize_complete(self, x_range: Tuple[float, float], y_range: Tuple[float, float], xi_range: Tuple[float, float], eta_range: Tuple[float, float], geodesics_params: List[Tuple], E_range: Optional[Tuple[float, float]] = None, hbar: float = 1.0, resolution: int = 50) -> Tuple:
    def visualize_complete(self,
                          x_range: Tuple[float, float],
                          y_range: Tuple[float, float],
                          xi_range: Tuple[float, float],
                          eta_range: Tuple[float, float],
                          geodesics_params: List[Tuple],
                          E_range: Optional[Tuple[float, float]] = None,
                          hbar: float = 1.0,
                          resolution: int = 50) -> Tuple:
        """
        Create a complete 18-panel visualization combining geometry and physics
        Parameters
        ----------
        x_range, y_range : tuple
            Configuration space domain
        xi_range, eta_range : tuple
            Momentum space domain
        geodesics_params : list
            Geodesic parameters: (x0, y0, xi0, eta0, t_max, color)
        E_range : tuple, optional
            Energy interval for spectral analysis
        hbar : float
            Reduced Planck constant
        resolution : int
            Grid resolution
        Returns
        -------
        fig, geodesics, periodic_orbits, caustics
        """
        # Compute geodesics with caustic detection
        geodesics = self._compute_geodesics(geodesics_params)
        # Search for periodic orbits
        periodic_orbits = []
        if E_range:
            energies = np.linspace(E_range[0], E_range[1], 5)
            for E in energies:
                orbits = self.geo.find_periodic_orbits_2d(
                    E, x_range, y_range, xi_range, eta_range, n_attempts=20
                )
                periodic_orbits.extend(orbits)
        # Detect caustic structures
        caustics = []
        if geodesics:
            t_samples = np.linspace(0, geodesics[0].t[-1], 5)
            for t in t_samples:
                caustics.extend(self.geo.detect_caustic_structures(geodesics, t))
        # Create full figure
        fig = self._create_complete_figure(
            E_range, x_range, y_range, xi_range, eta_range,
            geodesics, periodic_orbits, caustics, hbar, resolution
        )
        return fig, geodesics, periodic_orbits, caustics

Create a complete 18-panel visualization combining geometry and physics

Parameters

x_range, y_range : tuple
    Configuration space domain
xi_range, eta_range : tuple
    Momentum space domain
geodesics_params : list
    Geodesic parameters: (x0, y0, xi0, eta0, t_max, color)
E_range : tuple, optional
    Energy interval for spectral analysis
hbar : float
    Reduced Planck constant
resolution : int
    Grid resolution

Returns

fig, geodesics, periodic_orbits, caustics
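
Each returned geodesic carries `t` and `H` arrays, which the energy-conservation panel turns into the relative drift |H(t) - H(0)| / (|H(0)| + 1e-10). A minimal sketch of the same check on a trajectory integrated with SciPy's `solve_ivp`; the Hénon-Heiles Hamiltonian, the initial condition, and the solver tolerances are assumptions for illustration, not what `compute_geodesic` uses internally.

```python
import numpy as np
from scipy.integrate import solve_ivp

def H(x, y, xi, eta):
    # Henon-Heiles Hamiltonian, chosen only as a test case
    return 0.5 * (xi**2 + eta**2) + 0.5 * (x**2 + y**2) + x**2 * y - y**3 / 3

def hamilton_rhs(t, z):
    x, y, xi, eta = z
    # Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq
    return [xi, eta, -(x + 2 * x * y), -(y + x**2 - y**2)]

t_eval = np.linspace(0.0, 200.0, 2001)
sol = solve_ivp(hamilton_rhs, (0.0, 200.0), [0.0, 0.1, 0.35, 0.0],
                method="DOP853", rtol=1e-10, atol=1e-12, t_eval=t_eval)

H_t = H(*sol.y)                      # sol.y has shape (4, len(t_eval))
# Same normalization as the energy-conservation panel
rel_drift = np.abs(H_t - H_t[0]) / (np.abs(H_t[0]) + 1e-10)
print(f"max |dH/H0| = {rel_drift.max():.2e}")
```

A non-symplectic Runge-Kutta scheme like this one typically shows a slow secular growth of the energy error over long runs, so the semilogy curve creeps upward; a symplectic integrator would instead keep the error bounded.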

class Utilities2D:
class Utilities2D:
    """Additional analysis tools for 2D systems"""
    @staticmethod
    def compute_winding_number(geo: Geodesic2D) -> float:
        """
        Compute winding number around origin
        """
        angles = np.arctan2(geo.y, geo.x)
        angles_unwrapped = np.unwrap(angles)
        winding = (angles_unwrapped[-1] - angles_unwrapped[0]) / (2 * np.pi)
        return winding

    @staticmethod
    def compute_rotation_numbers(geo: Geodesic2D) -> Tuple[float, float]:
        """
        Compute rotation numbers (ω_x, ω_y)
        """
        theta_x = np.arctan2(geo.xi, geo.x)
        theta_y = np.arctan2(geo.eta, geo.y)
        theta_x = np.unwrap(theta_x)
        theta_y = np.unwrap(theta_y)
        omega_x = (theta_x[-1] - theta_x[0]) / (geo.t[-1] - geo.t[0])
        omega_y = (theta_y[-1] - theta_y[0]) / (geo.t[-1] - geo.t[0])
        return omega_x / (2*np.pi), omega_y / (2*np.pi)

    @staticmethod
    def detect_kam_tori(periodic_orbits: List[PeriodicOrbit2D],
                       tolerance: float = 0.1) -> Dict:
        """
        Detect KAM tori from periodic orbits
        """
        if not periodic_orbits:
            return {'n_tori': 0, 'tori': []}
        actions = np.array([orb.action for orb in periodic_orbits])
        # Cluster by action: Ward linkage, cut at height `tolerance`
        # (linkage/fcluster from scipy.cluster.hierarchy, imported at module level)
        if len(actions) > 1:
            Z = linkage(actions.reshape(-1, 1), method='ward')
            clusters = fcluster(Z, t=tolerance, criterion='distance')
            n_tori = len(np.unique(clusters))
        else:
            n_tori = 1
            clusters = [1]
        # Analyze each torus
        tori = []
        for torus_id in np.unique(clusters):
            orbits_in_torus = [orb for i, orb in enumerate(periodic_orbits)
                              if clusters[i] == torus_id]
            mean_action = np.mean([orb.action for orb in orbits_in_torus])
            mean_energy = np.mean([orb.energy for orb in orbits_in_torus])
            mean_period = np.mean([orb.period for orb in orbits_in_torus])
            stabilities = [orb.stability_1 for orb in orbits_in_torus]
            # Plain Python bool rather than numpy.bool_
            is_stable = bool(np.mean(stabilities) < 0)
            tori.append({
                'id': int(torus_id),
                'n_orbits': len(orbits_in_torus),
                'action': mean_action,
                'energy': mean_energy,
                'period': mean_period,
                'stable': is_stable
            })
        return {
            'n_tori': n_tori,
            'tori': tori
        }

Additional analysis tools for 2D systems

@staticmethod
def compute_winding_number(geo: src.geometry_2d.Geodesic2D) -> float:

Compute winding number around origin
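
`compute_winding_number` only reads `geo.x` and `geo.y`. The sketch below applies the same steps to plain NumPy arrays so it runs without a `Geodesic2D` (the helper name `winding_number` is hypothetical). It assumes the path never passes through the origin and is sampled finely enough that consecutive angles differ by less than π, which `np.unwrap` needs in order to remove the ±2π jumps of `arctan2`.

```python
import numpy as np

def winding_number(x, y):
    """Net number of turns of the planar path (x(t), y(t)) around the origin."""
    # Same steps as compute_winding_number, applied to plain arrays
    angles = np.unwrap(np.arctan2(y, x))
    return (angles[-1] - angles[0]) / (2 * np.pi)

t = np.linspace(0.0, 4 * np.pi, 1000)
w_around = winding_number(np.cos(t), np.sin(t))        # unit circle about the origin, two turns
w_offset = winding_number(2 + np.cos(t), np.sin(t))    # circle centred at (2, 0)
print(w_around, w_offset)   # approximately 2 and 0
```

The second path encloses no origin, so its angle oscillates and returns to its start, giving zero net winding.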

@staticmethod
def compute_rotation_numbers(geo: src.geometry_2d.Geodesic2D) -> Tuple[float, float]:

Compute rotation numbers (ω_x, ω_y)
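
`compute_rotation_numbers` returns the mean angular velocity of the phase angles arctan2(ξ, x) and arctan2(η, y), divided by 2π, so the values are turns per unit time. The sketch below repeats that arithmetic for two uncoupled harmonic oscillators with ω_x = 1 and ω_y = 2, an assumed test case; `rotation_numbers` is a hypothetical standalone helper taking plain arrays. For these oscillators the motion is clockwise in the (x, ξ) and (y, η) planes under this angle convention, so the results come out negative: about -1/(2π) and -1/π.

```python
import numpy as np

def rotation_numbers(t, x, y, xi, eta):
    """Mean phase-angle frequencies (turns per unit time) in the (x, xi) and (y, eta) planes."""
    theta_x = np.unwrap(np.arctan2(xi, x))
    theta_y = np.unwrap(np.arctan2(eta, y))
    duration = t[-1] - t[0]
    omega_x = (theta_x[-1] - theta_x[0]) / duration
    omega_y = (theta_y[-1] - theta_y[0]) / duration
    return omega_x / (2 * np.pi), omega_y / (2 * np.pi)

# H = (xi^2 + x^2)/2 + (eta^2 + 4 y^2)/2, sampled over whole periods of both oscillators
t = np.linspace(0.0, 20 * np.pi, 20001)
x, xi = np.cos(t), -np.sin(t)                  # omega_x = 1
y, eta = np.cos(2 * t), -2 * np.sin(2 * t)     # omega_y = 2
nu_x, nu_y = rotation_numbers(t, x, y, xi, eta)
print(nu_x, nu_y)
```

When ω ≠ 1 the phase angle does not advance uniformly, so the estimate is exact only over whole periods; over long trajectories the end-point error shrinks like 1/(t_end - t_start).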

@staticmethod
def detect_kam_tori(periodic_orbits: List[src.geometry_2d.PeriodicOrbit2D], tolerance: float = 0.1) -> Dict:

Detect KAM tori from periodic orbits
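
`detect_kam_tori` groups orbits whose actions lie close together using SciPy's hierarchical clustering: `linkage` with Ward's method, then `fcluster` with `criterion='distance'`. The sketch below runs the same two calls on a hand-picked list of action values (made-up numbers, not library output) to show how `tolerance` acts. For two single orbits the Ward merge height equals their action difference, so actions within roughly `tolerance` of each other end up in one group.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

# Illustrative action values: three near S = 1, two near S = 2.5, one at S = 4
actions = np.array([1.00, 1.02, 0.99, 2.50, 2.52, 4.00])

Z = linkage(actions.reshape(-1, 1), method="ward")   # 1-D data passed as an (n, 1) array
labels = fcluster(Z, t=0.1, criterion="distance")    # cut the tree at height 0.1
n_groups = len(np.unique(labels))
group_sizes = sorted(np.bincount(labels)[1:])        # fcluster labels start at 1
print(n_groups, [int(s) for s in group_sizes])
```

Raising `tolerance` past the merge height between two groups (about 1.7 here between the S ≈ 2.5 pair and S = 4) fuses them into a single torus.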